# Image Denoising

Acquired images are corrupted by noise. Noise consists of random variations that
are due to the intrinsic quantum nature of light and to the thermal agitation of
electrons in the circuits.

Denoising is a fundamental problem in image processing, as it requires
modeling the nature of the image itself and therefore provides a good experimental
validation of such models. Over the years, many approaches have been proposed for
denoising images, based on various assumptions about the image.

Experimental validation:




T. Plotz and S. Roth, ‘Benchmarking Denoising Algorithms With Real Photographs’, presented at the Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2017, pp. 1586–1595. Available: https://openaccess.thecvf.com/content_cvpr_2017/html/Plotz_Benchmarking_Denoising_Algorithms_CVPR_2017_paper.html



## BM3D

BM3D is one of the best-performing approaches among the methods that are not
based on learning.

Y. Mäkinen, L. Azzari, A. Foi, 2020, "Collaborative Filtering of Correlated Noise: Exact Transform-Domain Variance for Improved Shrinkage and Patch Matching", in IEEE Transactions on Image Processing, vol. 29, pp. 8339-8354.

We can install BM3D from pip
```
    pip install bm3d
```


In [None]:
# Run this cell if you are developing the code locally.
# The autoreload extension (mode 2) reloads modules before executing a cell,
# so edits to the local sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import sys
import site
# Register the parent directory as a site directory so that packages located
# there (presumably the local `mug` package used below) can be imported.
site.addsitedir('../') 

Let's apply the denoising method on a test image

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from bm3d import bm3d
# Import the submodules explicitly: a bare `import skimage` does not guarantee
# that `skimage.data` and `skimage.metrics` are loaded on every version.
import skimage.data
import skimage.metrics

# Standard deviation of the simulated Gaussian noise, in gray levels.
noise_std = 20
# First channel of the test image, converted to float (as in the later cells)
# so that the noise addition and the PSNR are computed in floating point.
im = skimage.data.astronaut()[:, :, 0].astype(float)
noise = np.random.normal(0, noise_std, im.shape)
noisy = im + noise
# Non-blind denoising: BM3D is given the true noise standard deviation.
denoised = bm3d(noisy, noise_std)

# display the result with quality metric
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
ax[0].imshow(im, cmap='gray')
ax[0].set_title('Original image')
ax[1].imshow(noisy, cmap='gray')
ax[1].set_title(f'Noisy image {skimage.metrics.peak_signal_noise_ratio(im, noisy, data_range=255):.2f}dB')
ax[2].imshow(denoised, cmap='gray')
ax[2].set_title(f'Denoised image {skimage.metrics.peak_signal_noise_ratio(im, denoised, data_range=255):.2f}dB')
for a in ax:
    a.set_axis_off()

Benchmarking

In [None]:
from mug import denoising

# Benchmark BM3D on the Set12 images for noise levels 5, 10, ..., 35,
# with blind=False.
result_bm3d = denoising.benchmark(
    bm3d,
    '../data/Set12/*.png',
    blind=False,
    noise_levels=range(5, 40, 5),
)
# Rows are noise levels, columns are image names: one PSNR curve per image.
bm3d_psnr_table = result_bm3d.pivot_table('psnr', 'noise level', 'image name')
bm3d_psnr_table.plot()
plt.ylabel('PSNR')
plt.title('Set12 Benchmark for BM3D')

## DnCNN
In DnCNN, the task is set to predict the residual image between the noisy and
the noise-free image. The network is composed of a deep (17-stage) sequence of
blocks composed of convolution, batch normalization and ReLU. The training is
done by cropping 50x50 patches of images.

Both a supervised version with a known noise level and a blind version can be trained.

Zhang, Kai, et al. "Beyond a gaussian denoiser: Residual learning of deep cnn for 
image denoising." IEEE transactions on image processing 26.7 (2017): 3142-3155.

In [None]:
# Imports are repeated here so that this cell runs on its own: it is executed
# before the training cell that imports the same names.
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader

from mug import denoising

# Learning-rate search for DnCNN: build a fresh model and plot the loss
# recorded by the finder against the learning rate on log-log axes.
dncnn_model = denoising.DnCNN()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dncnn_model.to(device)
# NOTE(review): the arguments [64,64], [10,30] and 1000 are presumably the
# patch size, the noise-level range and the number of samples - confirm
# against mug.denoising.
train_dataset = denoising.ImageFolderDataset('../data/Set12/*', denoising.DnCNNAugmenter([64,64], [10,30]),1000)
# A single data loader; the original built it twice and only the second one
# (batch_size=12) was ever used.
train_dl = DataLoader(train_dataset, batch_size=12)
optimizer = optim.AdamW(dncnn_model.parameters(), lr=1e-6)
loss_fn = nn.MSELoss()
# lr_finder lives in the `denoising` module (see the training cell).
# NOTE(review): 1e-6 and 10 are presumably the bounds of the scanned
# learning-rate range - confirm.
lropt, h = denoising.lr_finder(dncnn_model,optimizer,loss_fn,train_dl,device,1e-6,10)
plt.loglog(h['learning rate'],h['loss'])


In [None]:
from mug import denoising
from mug import utils
import torch
from torch.utils.data import DataLoader
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt

# Train DnCNN and plot the training / validation loss per epoch.
# The model is not re-created here, so training continues from the
# `dncnn_model` instance built in the learning-rate finder cell; uncomment the
# next line to start from a fresh network.
#dncnn_model = denoising.DnCNN()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dncnn_model.to(device)
# NOTE(review): training and validation read the same folder, so the
# validation loss is not measured on independent images.
train_dataset = denoising.ImageFolderDataset('../data/Set12/*', denoising.DnCNNAugmenter([64,64], [10,30]))
valid_dataset = denoising.ImageFolderDataset('../data/Set12/*', denoising.DnCNNAugmenter([64,64], [10,30]))
train_dl = DataLoader(train_dataset, batch_size=4)
valid_dl = DataLoader(valid_dataset, batch_size=4)
num_epochs = 100
# Defined before the optional learning-rate search below, which needs it.
loss_fn = nn.MSELoss()
# Optional: re-run the learning-rate search (needs an `optimizer` to exist).
#lropt, h0 = denoising.lr_finder(dncnn_model,optimizer,loss_fn,train_dl,device,1e-6,10)
#print(f'Optimal learning rate {lropt}')
# Use the learning rate found by the finder cell when it has been run,
# otherwise fall back to the default of 3e-4.
try:
    lropt
except NameError:
    lropt = 3e-4
optimizer = optim.AdamW(dncnn_model.parameters(), lr = lropt)
# NOTE(review): the scheduler is not passed to utils.train_network, so it is
# never stepped from this cell. Constructing OneCycleLR also appears to reset
# the optimizer's learning rate to max_lr / div_factor - confirm that this is
# intended.
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.1, steps_per_epoch=len(train_dl), epochs=num_epochs)
history = utils.train_network(dncnn_model,optimizer,loss_fn,train_dl,valid_dl,num_epochs,device)
plt.plot(history['train_epoch'],history['train_loss'])
plt.plot(history['valid_epoch'],history['valid_loss'])
plt.xlabel('Epoch')
plt.ylabel('Loss')

In [None]:
from mug import denoising

# Wrap the trained network in a denoiser and benchmark it on Set12 with
# blind=True, for noise levels 5, 10, ..., 35.
dncnn = denoising.DnCNNDenoiser(dncnn_model)
result_dncnn = denoising.benchmark(
    dncnn,
    '../data/Set12/*.png',
    blind=True,
    noise_levels=range(5, 40, 5),
)
# Rows are noise levels, columns are image names: one PSNR curve per image.
dncnn_psnr_table = result_dncnn.pivot_table('psnr', 'noise level', 'image name')
dncnn_psnr_table.plot()
plt.ylabel('PSNR')
plt.title('Set12 Benchmark for DnCNN')

In [None]:
# Overlay, for each method, the PSNR averaged over all images as a function
# of the noise level. The iteration order matches the legend entries.
for benchmark_result in (result_bm3d, result_dncnn):
    per_image_psnr = benchmark_result.pivot_table('psnr', 'noise level', 'image name')
    per_image_psnr.agg('mean', axis=1).plot()
plt.legend(['bm3d','dncnn'])
plt.ylabel('PSNR')
plt.title('Average Set12 Benchmark')

In [None]:
import skimage
import numpy as np
from mug import denoising

def psnr(x, y, vmax = 255):
    """Peak signal to noise ratio quality metric, in decibels.

    Args:
        x: First image (array-like).
        y: Second image (array-like), broadcastable against ``x``.
        vmax: Peak value of the signal (255 for 8-bit images).

    Returns:
        ``10 * log10(vmax**2 / MSE)`` as a float, where MSE is the mean
        squared difference between ``x`` and ``y``; ``inf`` when the MSE is
        zero (identical images).
    """
    import math
    # Work in floating point: subtracting and squaring unsigned integer
    # arrays (e.g. uint8 images) wraps around and corrupts the error.
    diff = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    mse = float(np.mean(np.square(diff)))
    if mse == 0:
        # Identical images: avoid the division by zero.
        return math.inf
    return 10 * math.log10(vmax * vmax / mse)
    
# Corrupt the first channel of the test image with additive Gaussian noise.
noise_std = 10
im = skimage.data.astronaut()[:,:,0]
noise = np.random.normal(0, noise_std, im.shape)
noisy = im + noise

# Denoise with the trained DnCNN; the denoiser is called without a noise level.
dncnn = denoising.DnCNNDenoiser(dncnn_model)
denoised = dncnn(noisy)

# display the result with quality metric
fig, ax = plt.subplots(1,3,figsize=(15,5))
panels = (
    (im, 'Original image'),
    (noisy, f'Noisy image {psnr(noisy, im):.2f}dB'),
    (denoised, f'Denoised image {psnr(denoised, im):.2f}dB'),
)
for a, (panel_image, panel_title) in zip(ax, panels):
    a.imshow(panel_image, cmap='gray')
    a.set_title(panel_title)
    a.set_axis_off()

In [None]:
# Train DRUNET and plot the training / validation loss per epoch.
drunet_model = denoising.DRUNET()
# Use the first CUDA device when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
drunet_model.to(device)
# NOTE(review): training and validation read the same folder, so the
# validation loss is not measured on independent images.
# The augmenter arguments [64,64] and [10,30] are presumably the patch size
# and the noise-level range - confirm against mug.denoising.
train_dataset = denoising.ImageFolderDataset('../data/Set12/*', denoising.DRUNETAugmenter([64,64], [10,30]))
valid_dataset = denoising.ImageFolderDataset('../data/Set12/*', denoising.DRUNETAugmenter([64,64], [10,30]))
train_dl = DataLoader(train_dataset, batch_size=2)
valid_dl = DataLoader(valid_dataset, batch_size=2)
num_epochs = 100
optimizer = optim.AdamW(drunet_model.parameters(), lr=1e-3)
# NOTE(review): the scheduler is not passed to utils.train_network below, so
# it is never stepped from this cell; constructing it may nevertheless change
# the optimizer's learning rate - confirm that this is intended.
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.1, steps_per_epoch=len(train_dl), epochs=num_epochs)
loss_fn = nn.MSELoss()
history = utils.train_network(drunet_model,optimizer,loss_fn,train_dl,valid_dl,num_epochs,device)
# Loss curves: training and validation loss against their epoch indices.
plt.plot(history['train_epoch'],history['train_loss'])
plt.plot(history['valid_epoch'],history['valid_loss'])
plt.xlabel('Epoch')
plt.ylabel('Loss')

In [None]:
# Run the trained DRUNET on one training sample and display the two input
# channels, the target and the prediction side by side.
# NOTE(review): x[0] and x[1] are presumably the noisy image and the
# noise-level map - confirm against DRUNETAugmenter.
x,y = train_dataset[0]
# The sample must be on the same device as the model, otherwise the forward
# pass fails when the model was moved to the GPU.
model_device = next(drunet_model.parameters()).device
with torch.no_grad():
    # Add the batch dimension, run on the model's device, and bring the
    # prediction back to the CPU for plotting and statistics.
    yhat = drunet_model(x.reshape([1,*x.shape]).to(model_device)).cpu()
fig, ax = plt.subplots(1,4)
ax[0].imshow(x[0].cpu().numpy().squeeze())
ax[1].imshow(x[1].cpu().numpy().squeeze())
ax[2].imshow(y.cpu().numpy().squeeze())
ax[3].imshow(yhat.numpy().squeeze())
# Mean of the second input channel, mean input-target difference, and
# standard deviation of the prediction error.
print(x[1].mean(),(x[0]-y).mean(),(yhat-y.cpu()).std())

In [None]:
from mug import denoising

# Wrap the trained network in a denoiser and benchmark it on Set12 with
# blind=False, for noise levels 5, 7, ..., 39.
drunet = denoising.DRUNETDenoiser(drunet_model)
result_drunet = denoising.benchmark(
    drunet,
    '../data/Set12/*.png',
    blind=False,
    noise_levels=range(5, 40, 2),
)
# Rows are noise levels, columns are image names: one PSNR curve per image.
drunet_psnr_table = result_drunet.pivot_table('psnr', 'noise level', 'image name')
drunet_psnr_table.plot()
plt.ylabel('PSNR')
plt.title('Set12 Benchmark for DRUNET')

In [None]:
from skimage.metrics import peak_signal_noise_ratio

# Denoise a noisy version of the test image with the trained DRUNET and
# report the PSNR of the estimate against the clean image.
# Use the network trained above (`model` was never defined in this notebook).
drunet = denoising.DRUNETDenoiser(drunet_model)
im = skimage.data.astronaut()[:,:,0].astype(float)
noise_std = 10.
noisy = im + np.random.normal(0, noise_std, size=im.shape)
# Denoise the noisy image (not the clean one), giving the denoiser the noise
# level that was actually simulated.
estimate = drunet(noisy, noise_std)
plt.subplot(121).imshow(noisy,cmap='gray')
plt.subplot(122).imshow(estimate,cmap='gray')
# Named `psnr_db` so that it does not shadow the `psnr` function defined in
# an earlier cell.
psnr_db = peak_signal_noise_ratio(im, estimate, data_range=255)
plt.title(f'PSNR {psnr_db:.2f}dB')


In [None]:
# Overlay, for each of the three methods, the PSNR averaged over all images
# as a function of the noise level. The plotting order matches the legend.
result_bm3d.pivot_table('psnr','noise level','image name').agg('mean',axis=1).plot()
result_dncnn.pivot_table('psnr','noise level','image name').agg('mean',axis=1).plot()
result_drunet.pivot_table('psnr','noise level','image name').agg('mean',axis=1).plot()
plt.legend(['bm3d','dncnn','drunet'])
plt.ylabel('PSNR')
# The plot compares all three methods, so the title must not single out DnCNN.
plt.title('Average Set12 Benchmark')

## Poisson - Gaussian noise 
The Poisson-Gaussian noise model captures both the shot noise and the
readout noise of cameras.

To denoise an image corrupted by Poisson-Gaussian noise, a variance
stabilization step converts the signal-dependent noise into
signal-independent noise of variance 1.

In [None]:
# Smooth test image used as the noise-free signal.
im = skimage.filters.gaussian(skimage.data.astronaut()[:,:,0].astype(float), 4)

# Parameters of the simulated Poisson-Gaussian noise.
gain = 1
offset = 10
sigma = 1

# Poisson counts drawn from the image, scaled by the gain and shifted by the
# offset, plus zero-mean Gaussian noise of standard deviation sigma.
poisson_counts = np.random.poisson(im)
gaussian_part = np.random.normal(0, sigma, size=im.shape)
noisy = gain * poisson_counts + offset + gaussian_part

# Apply the generalized Anscombe transform built from the known parameters.
gat = denoising.generalized_anscombe_transform(gain, offset, sigma)
# gat.calibrate(noisy)
variance_stabilized = gat(noisy)

# Local statistics before and after the transform (presumably local mean and
# local variance, as suggested by the axis labels - confirm).
x0, y0 = denoising.compute_local_statistics(noisy)
x1, y1 = denoising.compute_local_statistics(variance_stabilized)

# Scatter one value out of 50 for each case, side by side.
fig, ax = plt.subplots(1,2,figsize=(15,5))
panels = (
    (x0, y0, 'Before stabilization'),
    (x1, y1, 'After stabilization'),
)
for axis, (local_x, local_y, panel_title) in zip(ax, panels):
    axis.scatter(local_x.flatten()[::50], local_y.flatten()[::50], alpha=0.1)
    axis.set_title(panel_title)
    axis.set_xlabel('Mean intensity')
    axis.set_ylabel('Noise variance')