# Plug-and-Play Image Restoration

<a target="_blank" href="https://colab.research.google.com/github/generativemodelingmva/generativemodelingmva.github.io/blob/main/tp2324/tp8_pnp.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

<br/><br/><br/>

This practical session is dedicated to the implementation of plug-and-play algorithms with pre-learned denoisers.

In [None]:
import numpy as np
import torch
from torch.fft import fft2, ifft2, fftshift, ifftshift
import matplotlib.pyplot as plt
import time

print(torch.__version__)
pi = torch.pi
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

!pip install bm3d
# !pip install deepinv
# or the latest version of deepinv:
!pip install git+https://github.com/deepinv/deepinv.git#egg=deepinv
import deepinv as dinv

# Download and unzip the files for this session (comment out these two lines if the files are already present)
!wget https://perso.telecom-paristech.fr/aleclaire/mva/tp8.zip
!unzip tp8.zip

In [None]:
def rgb2gray(u):
    return 0.2989 * u[:,:,0] + 0.5870 * u[:,:,1] + 0.1140 * u[:,:,2]

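def str2(x):
    # format a number with two decimals
    return "{:.2f}".format(x)

def psnr(uref,ut,M=1):
    # detach() is needed when ut is the output of a network (requires_grad=True)
    uref = uref.detach().cpu().numpy()
    ut = ut.detach().cpu().numpy()
    rmse = np.sqrt(np.mean((uref-ut)**2))
    return 20*np.log10(M/rmse)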

def tensor2im(x):
    return x.detach().cpu().permute(2,3,1,0).squeeze().clip(0,1)

# viewimage
import tempfile
import IPython
from skimage.transform import rescale

def viewimage(im, normalize=True,vmin=0,vmax=1,z=2,order=0,titre='',displayfilename=False):
    # By default, values are scaled with black=0 and white=1
    # In order to adapt the dynamics to the image, enter vmin and vmax as None
    im = im.detach().cpu().permute(2,3,1,0).squeeze()
    imin= np.array(im).astype(np.float32)
    channel_axis = 2 if len(im.shape)>2 else None
    imin = rescale(imin, z, order=order, channel_axis=channel_axis)
    if normalize:
        if vmin is None:
            vmin = imin.min()
        if vmax is None:
            vmax = imin.max()
        if np.abs(vmax-vmin)>1e-10:
            imin = (imin.clip(vmin,vmax)-vmin)/(vmax-vmin)
        else:
            imin = np.zeros_like(imin)  # constant image
    else:
        imin=imin.clip(0,255)/255 
    imin=(imin*255).astype(np.uint8)
    filename=tempfile.mktemp(titre+'.png')
    if displayfilename:
        print (filename)
    plt.imsave(filename, imin, cmap='gray')
    IPython.display.display(IPython.display.Image(filename))


# Alternative to viewimage if the latter does not work (expects an (M,N) or (M,N,C) array, e.g. tensor2im(x))
def Viewimage(im,dpi=100,cmap='gray'):
    plt.figure(dpi=dpi)
    if cmap is None:
        plt.imshow(np.array(im))
    else:
        plt.imshow(np.array(im),cmap=cmap)
    plt.axis('off')
    plt.show()

<br/><br/><br/>

# Exercise 1: Plug-and-Play Image Deblurring (with periodic convolution)

In [None]:
# Open the image
x0 = torch.tensor(plt.imread('im/simpson512crop.png'),device=device)
# x0 = torch.tensor(plt.imread('im/parrots.png'),device=device); x0 = x0[100:356,370:626,:]
# x0 = torch.tensor(plt.imread('im/marge2.png'),device=device)
# x0 = torch.tensor(plt.imread('im/simpson512.png'),device=device)
M,N,C = x0.shape
# Permute dimensions to the (batch, channel, height, width) tensor convention
x0 = x0.permute(2,0,1).unsqueeze(0)

viewimage(x0)

In [None]:
# Load a blur kernel
kt = torch.tensor(np.loadtxt('kernels/kernel8.txt'))
# kt = torch.tensor(np.loadtxt('kernels/levin7.txt'))
(m,n) = kt.shape

plt.imshow(kt)
plt.title('Blur kernel')
plt.show()

# Embed the kernel in an MxN image and put its center at pixel (0,0)
k = torch.zeros((M,N),device=device)
k[0:m,0:n] = kt/torch.sum(kt)
k = torch.roll(k,(-int(m/2),-int(n/2)),(0,1))
#k = k[:,:,None].repeat(1,1,3)
#k = k.permute(2,0,1).unsqueeze(0)
k = k[None,None,:,:]  # shape (1,1,M,N): the same kernel is applied to each color channel
fk = fft2(k)

viewimage(fftshift(k),vmin=None,vmax=None)
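As a sanity check of the centering performed by `torch.roll` above, the following sketch compares the FFT product with a periodic convolution computed directly as a weighted sum of circularly shifted copies of the image. It runs on synthetic data; `embed_kernel` and `circular_conv_direct` are hypothetical helpers, not part of the notebook, and the `_demo` names avoid overwriting the notebook variables.

```python
import torch
from torch.fft import fft2, ifft2

def embed_kernel(kt, M, N):
    """Zero-pad an m x n kernel to M x N, normalize it, and move its center to pixel (0,0)."""
    m, n = kt.shape
    k = torch.zeros((M, N), dtype=kt.dtype)
    k[:m, :n] = kt / kt.sum()
    return torch.roll(k, (-(m // 2), -(n // 2)), (0, 1))

def circular_conv_direct(x, kt):
    """Reference periodic convolution: weighted sum of circularly shifted copies of x."""
    m, n = kt.shape
    kt = kt / kt.sum()
    out = torch.zeros_like(x)
    for i in range(m):
        for j in range(n):
            # tap (i, j) sits at offset (i - m//2, j - n//2) from the kernel center
            out += kt[i, j] * torch.roll(x, (i - m // 2, j - n // 2), (-2, -1))
    return out

# Usage: the FFT product and the direct sum agree on random data
torch.manual_seed(0)
x_demo = torch.rand(1, 3, 16, 20, dtype=torch.float64)
kt_demo = torch.rand(5, 3, dtype=torch.float64)
fk_demo = fft2(embed_kernel(kt_demo, 16, 20))   # shape (M, N), broadcasts over batch and channels
y_fft_demo = ifft2(fk_demo * fft2(x_demo)).real
print(torch.max(torch.abs(y_fft_demo - circular_conv_direct(x_demo, kt_demo))).item())  # tiny
```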


## Test a pre-learned denoiser

Compute a noisy image
$$ y = x_0 + \xi $$
where $\xi \sim \mathcal{N}(0,\nu^2 \mathsf{Id})$.
Denoise the image $y$ by using a pre-learned denoiser.

In [None]:
nu = 2/255  # noise standard deviation (e.g. also try 0.15)
y = x0 + nu*torch.randn_like(x0,device=device)

# Load the DRUNet denoiser
# https://deepinv.github.io/deepinv/stubs/deepinv.models.DRUNet.html
# D = dinv.models.DRUNet(pretrained='ckpts/drunet_color.pth').to(device)

# Load the BM3D denoiser
# https://deepinv.github.io/deepinv/stubs/deepinv.models.BM3D.html
# D = dinv.models.BM3D().to(device)

# Load the DnCNN denoiser (WARNING: the proposed weights are only trained for noise level sigma = 2/255)
# https://deepinv.github.io/deepinv/stubs/deepinv.models.DnCNN.html
D = dinv.models.DnCNN(pretrained='ckpts/dncnn_sigma2_color.pth').to(device)
# D = dinv.models.DnCNN(pretrained='ckpts/dncnn_sigma2_lipschitz_color.pth').to(device)

# TV denoiser (only in the latest version of deepinv)
# Dtv = dinv.models.TVDenoiser().to(device)
# def D(x,sigma):
#   return Dtv(x,ths=2*sigma**2)

# noisy image
print('PSNR(x0,y) = %.2f'%psnr(x0,y))
viewimage(y)
# denoise image
Dy = D(y,sigma=nu)
print('PSNR(x0,Dy) = %.2f'%psnr(x0,Dy))
viewimage(Dy)


## Image deblurring with PnP-PGD 

Implement the forward model
$$ y = A(x_0) + \xi $$
where $\xi \sim \mathcal{N}(0,\nu^2 \mathsf{Id})$.
Write functions implementing the operator $A(x)$ and the data-fidelity term $f(x)$.
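A minimal sketch of one possible answer, assuming the transfer function `fk` of shape (1,1,M,N) computed above. The helpers `blur_periodic` and `data_fidelity` are hypothetical names; in the cell below one could set `A = lambda x: blur_periodic(x, fk)`. The usage part runs on synthetic data (5x5 box blur) with `_demo` names so as not to overwrite the notebook variables.

```python
import torch
from torch.fft import fft2, ifft2

def blur_periodic(x, fk):
    """A(x): periodic convolution, i.e. pointwise product with the transfer function fk in Fourier."""
    return ifft2(fk * fft2(x)).real

def data_fidelity(x, y, fk, nu):
    """f(x) = ||A x - y||_2^2 / (2 nu^2)."""
    return torch.sum((blur_periodic(x, fk) - y) ** 2) / (2 * nu ** 2)

# Usage on synthetic data
torch.manual_seed(1)
M_demo, N_demo, nu_demo = 32, 32, 0.01
x0_demo = torch.rand(1, 3, M_demo, N_demo)
k_demo = torch.zeros(M_demo, N_demo)
k_demo[:5, :5] = 1.0 / 25                               # 5x5 box blur
fk_demo = fft2(torch.roll(k_demo, (-2, -2), (0, 1)))    # kernel centered at pixel (0,0)
# Sample of the direct model: blur, then add white Gaussian noise of std nu
y_demo = blur_periodic(x0_demo, fk_demo) + nu_demo * torch.randn_like(x0_demo)
print(data_fidelity(x0_demo, y_demo, fk_demo, nu_demo).item())  # expected value 3*M*N/2 = 1536
```

At the true image, $Ax_0 - y = -\xi$, so $f(x_0) = \|\xi\|^2/(2\nu^2)$ fluctuates around half the number of pixel values, which is a quick check of the $\frac{1}{2\nu^2}$ normalization.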

In [None]:
nu = .01  # noise level
torch.manual_seed(1)  # fix random seed for reproducibility

# Define corresponding operator and data-fidelity
def A(x):
    return ### TODO ###

def f(x):
    return ### TODO ###

# Draw a sample of the direct model for image deblurring (apply blur and add Gaussian noise)
### TODO ###

viewimage(y)

In this question, we will perform deblurring with the PnP-PGD algorithm 
$$ x_{k+1} = D_\sigma \circ (\operatorname{Id} - \tau \nabla f) (x_k) $$
where $f(x) = \frac{1}{2\nu^2} \|Ax-y\|_2^2$ is the data-fidelity term.

Recall that $\tau$ should be $< \frac{2}{L}$ where $L$ is the Lipschitz constant of $\nabla f$.
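For the periodic blur, $\nabla f(x) = \frac{1}{\nu^2} A^T(Ax - y)$, where $A^T$ multiplies the spectrum by $\overline{\hat{k}}$, and $L = \max |\hat{k}|^2 / \nu^2$ because $A^T A$ is diagonal in the Fourier domain. The sketch below computes both quantities; the helper names and the random 3x7 kernel are hypothetical.

```python
import torch
from torch.fft import fft2, ifft2

def blur_periodic(x, fk):
    return ifft2(fk * fft2(x)).real

def blur_adjoint(x, fk):
    # A^T for a real periodic convolution: multiply by the complex conjugate of fk
    return ifft2(torch.conj(fk) * fft2(x)).real

def grad_data_fidelity(x, y, fk, nu):
    # grad f(x) = A^T (A x - y) / nu^2
    return blur_adjoint(blur_periodic(x, fk) - y, fk) / nu ** 2

def lipschitz_constant(fk, nu):
    # L = ||A^T A|| / nu^2 = max |fk|^2 / nu^2 (A^T A is diagonal in Fourier)
    return (torch.abs(fk) ** 2).max().item() / nu ** 2

# Usage: for a nonnegative kernel with unit sum, max |fk| = |fk(0)| = 1, hence L = 1/nu^2
torch.manual_seed(0)
k_demo = torch.zeros(32, 32, dtype=torch.float64)
k_demo[:3, :7] = torch.rand(3, 7, dtype=torch.float64)
fk_demo = fft2(torch.roll(k_demo / k_demo.sum(), (-1, -3), (0, 1)))
L_demo = lipschitz_constant(fk_demo, 0.05)
tau_demo = 1.9 / L_demo   # any step in (0, 2/L) is admissible
print(L_demo)             # approximately 400
```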

Complete the following cell progressively, in order to address the following points:
1. Implement the PnP-PGD algorithm, and display the deblurred image.
2. Track the evolution of the PSNR.
3. Track the evolution of the residual $r_n = \frac{\|x_n - x_{n-1}\|}{\|x_0\|}$.
4. Try to adjust the parameter $\tau$ (gradient step size / strength of data-fidelity).
5. Try to adjust the parameter $s$ (strength of the denoiser).
6. Track the evolution of $v_n = \frac{\|D_\sigma(a_n) - D_\sigma(b_{n})\|}{\|a_n - b_n\|}$ (which lower bounds the Lipschitz constant of $D_\sigma$). <br/>
  You may choose sequences $(a_n), (b_n)$ that depend on the last iterates without additional evaluations of $D_\sigma$.
7. Store the PSNR/Residual tables and compare with several denoisers.
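The points above can be sketched as a single function, under two assumptions. First, the denoiser is called as `D(z, s)`; the deepinv models of this notebook are called as `D(z, sigma=s)`, so they would need a small wrapper such as `lambda z, s: D(z, sigma=s)`. Second, so that the snippet runs without pretrained weights, `toy_denoiser` is a hypothetical linear stand-in (the proximal operator of a quadratic gradient penalty), not DnCNN, DRUNet or BM3D. For $v_n$ it takes $a_n = z_n$ and $b_n = z_{n-1}$ with $z_n = x_n - \tau \nabla f(x_n)$, so that $D_\sigma(a_n) = x_{n+1}$ and $D_\sigma(b_n) = x_n$ are already available.

```python
import torch
from torch.fft import fft2, ifft2

def grad_energy_symbol(M, N):
    # Fourier symbol of grad^T grad for periodic forward differences
    xi = torch.arange(M, dtype=torch.float64)[:, None]
    zeta = torch.arange(N, dtype=torch.float64)[None, :]
    return 4 * torch.sin(torch.pi * xi / M) ** 2 + 4 * torch.sin(torch.pi * zeta / N) ** 2

def toy_denoiser(x, sigma):
    # HYPOTHETICAL stand-in for D_sigma: prox of (c/2)||grad x||^2 with c = 50 * sigma
    H = 1.0 / (1.0 + 50 * sigma * grad_energy_symbol(*x.shape[-2:]))
    return ifft2(fft2(x) * H).real

def pnp_pgd(x_init, grad_f, D, tau, s, xref, niter=100):
    """PnP-PGD x_{n+1} = D_s(x_n - tau grad f(x_n)), tracking PSNR, residual r_n and ratio v_n."""
    x = x_init.clone()
    z_prev = None
    psnrtab, rtab, vtab = [], [], []
    with torch.no_grad():                      # no autograd graph across iterations
        for _ in range(niter):
            z = x - tau * grad_f(x)             # gradient step on the data fidelity
            x_new = D(z, s)                     # denoising step
            # residual normalized by the reference image xref (the notebook's x0)
            rtab.append((torch.norm(x_new - x) / torch.norm(xref)).item())
            # v_n with a_n = z_n, b_n = z_{n-1}: D(z_n) = x_new and D(z_{n-1}) = x
            if z_prev is not None:
                vtab.append((torch.norm(x_new - x) / torch.norm(z - z_prev)).item())
            mse = torch.mean((x_new - xref) ** 2)
            psnrtab.append((10 * torch.log10(1.0 / mse)).item())  # PSNR for images in [0, 1]
            x, z_prev = x_new, z
    return x, psnrtab, rtab, vtab

# Usage on a small synthetic problem (3x3 box blur, periodic boundaries)
torch.manual_seed(0)
nu_demo = 0.01
x0_demo = torch.rand(1, 3, 32, 32, dtype=torch.float64)
k_demo = torch.zeros(32, 32, dtype=torch.float64)
k_demo[:3, :3] = 1.0 / 9
fk_demo = fft2(torch.roll(k_demo, (-1, -1), (0, 1)))
A_demo = lambda x: ifft2(fk_demo * fft2(x)).real
At_demo = lambda x: ifft2(torch.conj(fk_demo) * fft2(x)).real
y_demo = A_demo(x0_demo) + nu_demo * torch.randn_like(x0_demo)
grad_f_demo = lambda x: At_demo(A_demo(x) - y_demo) / nu_demo ** 2
L_demo = (torch.abs(fk_demo) ** 2).max().item() / nu_demo ** 2
xpgd_demo, psnrtab_demo, rtab_demo, vtab_demo = pnp_pgd(
    y_demo, grad_f_demo, toy_denoiser, tau=1.9 / L_demo, s=2 * nu_demo, xref=x0_demo)
```

Running the loop under `torch.no_grad()` keeps memory constant: otherwise each network call would extend the autograd graph through all previous iterates. For the toy denoiser $v_n \le 1$ because it is nonexpansive; with a learned denoiser, $v_n$ may exceed 1.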

In [None]:
tau = ###
s = 2*nu  # strength of the denoiser (that is, sigma)

# initialize with blurry image
x = y.clone()

psnrtab = []  # to store psnr
rtab = []     # to store residual
vtab = []     # to store denoiser variations

niter = 100
t0 = time.time()
print('[%4d/%4d] [%.5f s] PSNR = %.2f'%(0,niter,0,psnr(x0,y)))

for it in range(niter):
    
    ### TODO ###
    
    if (it+1)%10==0:
        print('[%4d/%4d] [%.5f s] PSNR = %.2f'%(it+1,niter,time.time()-t0,psnrt))
        viewimage(x)
        
xpgd = x   # save for later comparisons

plt.plot(psnrtab)
plt.title('PSNR')
plt.show()

# plt.semilogy(rtab)
# plt.title('Residual Norm')
# plt.show()

# plt.plot(vtab)
# plt.title('Denoiser Variations')
# plt.show()

In [None]:
# Save the tables obtained with various denoisers for later comparisons

# psnrtabtmp = psnrtab.copy()
# rtabtmp = rtab.copy()
# vtabtmp = vtab.copy()
# rtab_bm3d = rtab.copy()
# psnrtab_bm3d = psnrtab.copy()
# vtab_bm3d = vtab.copy()
# rtab_drunet = rtab.copy()
# psnrtab_drunet = psnrtab.copy()
# vtab_drunet = vtab.copy()
# rtab_dncnn = rtab.copy()
# psnrtab_dncnn = psnrtab.copy()
# vtab_dncnn = vtab.copy()
# rtab_dncnnlip = rtab.copy()
# psnrtab_dncnnlip = psnrtab.copy()
# vtab_dncnnlip = vtab.copy()

Compare with explicit regularizers (Tikhonov, smoothed TV).
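For reference, the filter `fh` used in `tych_deblur` in the next cell can be read, assuming periodic forward differences for $\nabla$, as the closed-form minimizer of a quadratic problem (hats denote the DFT):

$$ x_{\mathrm{tych}} = \mathop{\mathrm{arg\,min}}_x \ \frac{1}{2}\|Ax - y\|_2^2 + \lambda \|\nabla x\|_2^2 \quad \Longleftrightarrow \quad \left(A^T A + 2\lambda \nabla^T \nabla\right) x_{\mathrm{tych}} = A^T y ,$$

and, since $|e^{2i\pi\xi/M} - 1|^2 = 4\sin^2(\pi\xi/M)$,

$$ \hat{x}_{\mathrm{tych}}(\xi,\zeta) = \frac{\overline{\hat{k}(\xi,\zeta)}\, \hat{y}(\xi,\zeta)}{|\hat{k}(\xi,\zeta)|^2 + 8\lambda \left( \sin^2\frac{\pi\xi}{M} + \sin^2\frac{\pi\zeta}{N} \right)} ,$$

with `lam` playing the role of $\lambda$.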

In [None]:
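# Deblurring with Tikhonov regularization (closed-form solution in the Fourier domain)
def tych_deblur(y,k,lam=0.01):
    _,_,M,N = y.shape
    fk = fft2(k)  # transfer function of the kernel k given as argument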
    xi = torch.arange(M)
    ind = (xi>M/2)
    xi[ind] = xi[ind]-M
    zeta = torch.arange(N)
    ind = (zeta>N/2)
    zeta[ind] = zeta[ind]-N
    Xi,Zeta = torch.meshgrid(xi,zeta,indexing='ij')
    Xi = Xi[None,None,:,:].to(device)
    Zeta = Zeta[None,None,:,:].to(device)
    fh = torch.conj(fk)/(torch.abs(fk)**2 + 8 * lam * (torch.sin(pi*Xi/M)**2 + torch.sin(pi*Zeta/N)**2))
    return ifft2(fft2(y)*fh).real

xtych = tych_deblur(y,k)

# Smooth TV regularization
def stv_deblur(A,y,xinit,niter=1000,lam=0.002,ep=0.01,lr=None):
    if lr is None:
        lr = 1.9/(1+lam*8/ep)
    x = xinit.clone().requires_grad_(True)
    optim = torch.optim.SGD([x], lr=lr)
    losslist = []
    for it in range(niter):
        d1 = torch.roll(x,-1,2) - x
        d2 = torch.roll(x,-1,3) - x
        reg = torch.sum(torch.sqrt(ep**2+d1**2+d2**2))
        loss = torch.sum((A(x)-y)**2) + lam*reg
        losslist.append(loss.item())
        optim.zero_grad()
        loss.backward()
        optim.step()
    return x.detach(),losslist

xtv,losslist = stv_deblur(A,y,y.clone())
# plt.plot(losslist)
# plt.show()

# Display the results
print('PSNR(x0,xtych) = %.2f'%psnr(x0,xtych))
viewimage(xtych)
print('PSNR(x0,xtv) = %.2f'%psnr(x0,xtv))
viewimage(xtv)

Compare results obtained with PnP-PGD and with explicit regularizers.

In [None]:
plt.figure(dpi=180)
plt.subplot(2,2,1)
plt.imshow(tensor2im(x0), cmap='gray')
plt.title('Original',fontsize=8)
# plt.imshow(tensor2im(y), cmap='gray')
# plt.title('Degraded \n PSNR='+str2(psnr(x0,y)),fontsize=8)
plt.axis('off')
plt.subplot(2,2,2)
plt.imshow(tensor2im(xtych), cmap='gray')
plt.title('Tikhonov \n PSNR='+str2(psnr(x0,xtych)),fontsize=8)
plt.axis('off')
plt.subplot(2,2,3)
plt.imshow(tensor2im(xtv), cmap='gray')
plt.title('SmoothTV \n PSNR='+str2(psnr(x0,xtv)),fontsize=8)
plt.axis('off')
plt.subplot(2,2,4)
plt.imshow(tensor2im(xpgd), cmap='gray')
plt.title('PnP-PGD \n PSNR='+str2(psnr(x0,xpgd)),fontsize=8)
plt.axis('off')
plt.show()

Compare the residual norms, PSNR, and denoiser variations obtained with the various denoisers.

In [None]:
plt.figure(dpi=150)
plt.semilogy(rtab_dncnn,label='DnCNN')
plt.semilogy(rtab_dncnnlip,label='DnCNNLip')
plt.semilogy(rtab_drunet,label='DRUNet')
plt.semilogy(rtab_bm3d,label='BM3D')
plt.legend()
plt.title('Residual Norm')
plt.show()

plt.figure(dpi=150)
plt.plot(psnrtab_dncnn,label='DnCNN')
plt.plot(psnrtab_dncnnlip,label='DnCNNLip')
plt.plot(psnrtab_drunet,label='DRUNet')
plt.plot(psnrtab_bm3d,label='BM3D')
plt.plot([psnr(x0,xtv)]*niter,label='smoothTV')
plt.plot([psnr(x0,xtych)]*niter,label='Tikhonov')
plt.legend()
plt.title('PSNR')
plt.show()

plt.figure(dpi=150)
plt.plot(vtab_dncnn,label='DnCNN')
plt.plot(vtab_dncnnlip,label='DnCNNLip')
plt.plot(vtab_drunet,label='DRUNet')
plt.plot(vtab_bm3d,label='BM3D')
plt.legend()
plt.title('Denoiser Variations')
plt.show()

<br/><br/>

## Image deblurring with PnP-HQS

Implement the proximal operator of the data-fidelity term:
$$\mathsf{Prox}_{\tau f}(x) = \left( \frac{1}{\nu^2} A^T A + \frac{1}{\tau} \mathsf{Id} \right)^{-1} \left( \frac{1}{\nu^2} A^T y + \frac{1}{\tau} x \right) .$$
Since $A$ is here a periodic convolution, $A^T A$ is diagonal in the Fourier basis, so this computation can be done pointwise in the Fourier domain.
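A minimal sketch, written as a hypothetical closure factory `make_proxf` so that $A^T y$ is transformed only once; in the notebook one could then set `proxf = make_proxf(fk, y, nu)`. The usage part runs on synthetic data.

```python
import torch
from torch.fft import fft2, ifft2

def make_proxf(fk, y, nu):
    """Return proxf(x, tau) = (A^T A / nu^2 + Id / tau)^{-1} (A^T y / nu^2 + x / tau)
    for a periodic convolution A with transfer function fk."""
    fAty = torch.conj(fk) * fft2(y) / nu ** 2            # Fourier transform of A^T y / nu^2
    def proxf(x, tau):
        num = fAty + fft2(x) / tau
        den = torch.abs(fk) ** 2 / nu ** 2 + 1.0 / tau    # A^T A / nu^2 + Id / tau, diagonal in Fourier
        return ifft2(num / den).real
    return proxf

# Usage on random data
torch.manual_seed(0)
nu_demo, tau_demo = 0.05, 0.01
k_demo = torch.zeros(16, 16, dtype=torch.float64)
k_demo[:3, :3] = 1.0 / 9
fk_demo = fft2(torch.roll(k_demo, (-1, -1), (0, 1)))
y_demo = torch.rand(1, 3, 16, 16, dtype=torch.float64)
x_demo = torch.rand(1, 3, 16, 16, dtype=torch.float64)
u_demo = make_proxf(fk_demo, y_demo, nu_demo)(x_demo, tau_demo)
```

The output can be checked through the optimality condition $\frac{1}{\tau}(u - x) + \frac{1}{\nu^2} A^T(Au - y) = 0$.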

In [None]:
def proxf(x,tau):
    ### TODO ###

Implement the PnP-HQS algorithm 
$$ x_{k+1} = D_\sigma \circ \mathsf{Prox}_{\tau f} (x_k) .$$
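A sketch of the iteration, again with the hypothetical linear `toy_denoiser` in place of a learned denoiser so that it runs on its own (deepinv models would be wrapped as `lambda z, s: D(z, sigma=s)`). It also returns the residual norms $\|x_{k+1} - x_k\|$; `make_proxf` is the same hypothetical helper as for the prox.

```python
import torch
from torch.fft import fft2, ifft2

def grad_energy_symbol(M, N):
    # Fourier symbol of grad^T grad for periodic forward differences
    xi = torch.arange(M, dtype=torch.float64)[:, None]
    zeta = torch.arange(N, dtype=torch.float64)[None, :]
    return 4 * torch.sin(torch.pi * xi / M) ** 2 + 4 * torch.sin(torch.pi * zeta / N) ** 2

def toy_denoiser(x, sigma):
    # HYPOTHETICAL stand-in for D_sigma: prox of (c/2)||grad x||^2 with c = 50 * sigma
    H = 1.0 / (1.0 + 50 * sigma * grad_energy_symbol(*x.shape[-2:]))
    return ifft2(fft2(x) * H).real

def make_proxf(fk, y, nu):
    # Prox_{tau f} for f(u) = ||A u - y||^2 / (2 nu^2), A a periodic convolution
    fAty = torch.conj(fk) * fft2(y) / nu ** 2
    def proxf(x, tau):
        return ifft2((fAty + fft2(x) / tau) / (torch.abs(fk) ** 2 / nu ** 2 + 1.0 / tau)).real
    return proxf

def pnp_hqs(x_init, proxf, D, tau, s, niter=100):
    """PnP-HQS: x_{k+1} = D_s(Prox_{tau f}(x_k)); also returns the norms ||x_{k+1} - x_k||."""
    x = x_init.clone()
    rtab = []
    with torch.no_grad():
        for _ in range(niter):
            x_new = D(proxf(x, tau), s)
            rtab.append(torch.norm(x_new - x).item())
            x = x_new
    return x, rtab

# Usage on a small synthetic problem (3x3 box blur, periodic boundaries)
torch.manual_seed(0)
nu_demo = 0.05
k_demo = torch.zeros(16, 16, dtype=torch.float64)
k_demo[:3, :3] = 1.0 / 9
fk_demo = fft2(torch.roll(k_demo, (-1, -1), (0, 1)))
x0_demo = torch.rand(1, 3, 16, 16, dtype=torch.float64)
y_demo = ifft2(fk_demo * fft2(x0_demo)).real + nu_demo * torch.randn_like(x0_demo)
xhqs_demo, rtab_demo = pnp_hqs(y_demo, make_proxf(fk_demo, y_demo, nu_demo),
                               toy_denoiser, tau=0.01, s=0.1)
```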

In [None]:
### TODO ###

## Image deblurring with PnP-DRS

Implement the PnP-DRS algorithm 
$$ x_{k+1} = \left(\frac{1}{2} \mathsf{Id} + \frac{1}{2} (2 D_\sigma - \mathsf{Id}) \circ (2\mathsf{Prox}_{\tau f}-\mathsf{Id})\right) (x_k) .$$
Recall that the estimate of the solution of the inverse problem is obtained by applying one last proximal step to the iterate:
$$ \tilde{x}_k = \mathsf{Prox}_{\tau f} (x_k) .$$
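Expanding the operator gives $x_{k+1} = x_k + D_\sigma\left(2\mathsf{Prox}_{\tau f}(x_k) - x_k\right) - \mathsf{Prox}_{\tau f}(x_k)$, i.e. one proximal step and one denoiser call per iteration. The sketch below uses the same hypothetical `toy_denoiser` and `make_proxf` helpers as before, on synthetic data.

```python
import torch
from torch.fft import fft2, ifft2

def grad_energy_symbol(M, N):
    # Fourier symbol of grad^T grad for periodic forward differences
    xi = torch.arange(M, dtype=torch.float64)[:, None]
    zeta = torch.arange(N, dtype=torch.float64)[None, :]
    return 4 * torch.sin(torch.pi * xi / M) ** 2 + 4 * torch.sin(torch.pi * zeta / N) ** 2

def toy_denoiser(x, sigma):
    # HYPOTHETICAL stand-in for D_sigma: prox of (c/2)||grad x||^2 with c = 50 * sigma
    H = 1.0 / (1.0 + 50 * sigma * grad_energy_symbol(*x.shape[-2:]))
    return ifft2(fft2(x) * H).real

def make_proxf(fk, y, nu):
    # Prox_{tau f} for f(u) = ||A u - y||^2 / (2 nu^2), A a periodic convolution
    fAty = torch.conj(fk) * fft2(y) / nu ** 2
    def proxf(x, tau):
        return ifft2((fAty + fft2(x) / tau) / (torch.abs(fk) ** 2 / nu ** 2 + 1.0 / tau)).real
    return proxf

def pnp_drs(x_init, proxf, D, tau, s, niter=100):
    """PnP-DRS: x_{k+1} = x_k + D_s(2 Prox_{tau f}(x_k) - x_k) - Prox_{tau f}(x_k).
    Returns the estimate x_tilde = Prox_{tau f}(x_N)."""
    x = x_init.clone()
    with torch.no_grad():
        for _ in range(niter):
            z = proxf(x, tau)                  # z_k = Prox_{tau f}(x_k)
            x = x + D(2 * z - x, s) - z        # x_{k+1}
        return proxf(x, tau)

# Usage on a small synthetic problem (3x3 box blur, periodic boundaries)
torch.manual_seed(0)
nu_demo, tau_demo, s_demo = 0.05, 0.01, 0.1
k_demo = torch.zeros(16, 16, dtype=torch.float64)
k_demo[:3, :3] = 1.0 / 9
fk_demo = fft2(torch.roll(k_demo, (-1, -1), (0, 1)))
x0_demo = torch.rand(1, 3, 16, 16, dtype=torch.float64)
y_demo = ifft2(fk_demo * fft2(x0_demo)).real + nu_demo * torch.randn_like(x0_demo)
xdrs_demo = pnp_drs(y_demo, make_proxf(fk_demo, y_demo, nu_demo),
                    toy_denoiser, tau_demo, s_demo, niter=200)
```

Because `toy_denoiser` is exactly the prox of $\frac{c}{2}\|\nabla x\|^2$ with $c = 50 s$, the limit of this toy run is the minimizer of $\tau f + \frac{c}{2}\|\nabla x\|^2$, which makes a convenient sanity check; no such closed form is available with a learned denoiser.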

In [None]:
### TODO ###

<br/><br/><br/>

# Exercise 2: Image deblurring with non-periodic boundary conditions

Implement a PnP algorithm that addresses image deblurring with non-periodic boundary conditions.

You should adapt the code written in the previous cells to this new forward model. We advise you to make a copy of the whole notebook and to make the adaptation in a separate file.

Which PnP splitting method can you use for this particular setting?
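One possible direction, sketched under assumptions rather than given as the expected answer: without periodicity, $A^T A$ is no longer diagonalized by the FFT, so the Fourier formula for $\mathsf{Prox}_{\tau f}$ used by PnP-HQS and PnP-DRS does not apply, whereas PnP-PGD only needs $A$ and $A^T$. Below, $A$ is modeled as a "valid" convolution (only output pixels whose kernel support lies inside the image) with PyTorch's `conv2d`, and $A^T$ with `conv_transpose2d`; `make_valid_blur` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F

def make_valid_blur(kt, C):
    """Blur without any boundary assumption ('valid' convolution) and its adjoint.
    A maps (1, C, M, N) to (1, C, M-m+1, N-n+1); At maps back to (1, C, M, N)."""
    # conv2d computes a cross-correlation: flip the kernel to obtain a convolution
    w = torch.flip(kt, (0, 1))[None, None].repeat(C, 1, 1, 1)  # (C, 1, m, n): same filter per channel
    def A(x):
        return F.conv2d(x, w, groups=C)
    def At(v):
        # conv_transpose2d with the same weights is the adjoint of conv2d
        return F.conv_transpose2d(v, w, groups=C)
    return A, At

# Usage: adjoint test <A x, v> = <x, A^T v> on random data
torch.manual_seed(0)
kt_demo = torch.rand(5, 7, dtype=torch.float64)
kt_demo /= kt_demo.sum()
A_demo, At_demo = make_valid_blur(kt_demo, C=3)
x_demo = torch.rand(1, 3, 20, 24, dtype=torch.float64)
v_demo = torch.rand(1, 3, 16, 18, dtype=torch.float64)
print(torch.sum(A_demo(x_demo) * v_demo).item(), torch.sum(x_demo * At_demo(v_demo)).item())
```

For a nonnegative kernel with unit sum, $\|A\| \le \|k\|_1 = 1$, so $L \le 1/\nu^2$ and the step-size condition of PnP-PGD carries over. Note that $y$ is now smaller than the unknown image, so the initialization must be adapted (e.g. $A^T y$ or a padded $y$).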

<br/><br/><br/>

# Exercise 3: Image super-resolution

Implement a PnP algorithm that addresses image super-resolution.

You should adapt the code written in the previous cells for super-resolution. The forward model for super-resolution involves an anti-aliasing filter whose Fourier transform is given in the next cell.

We advise you to make a copy of the whole notebook and to make the adaptation in a separate file.

In [None]:
# Adjust the framework to address super-resolution with smoothed TV
# For anti-aliasing, you may use the Butterworth filter of order `order` and cut-off frequency `fc`
#   given below

# fc is the cut-off frequency normalized in (0,1)
def butterworth(M,N,fc=.5,order=5):
    xi = torch.arange(M)
    ind = (xi>M/2)
    xi[ind] = xi[ind]-M
    zeta = torch.arange(N)
    ind = (zeta>N/2)
    zeta[ind] = zeta[ind]-N
    Xi,Zeta = torch.meshgrid(xi,zeta,indexing='ij')
    Xi = Xi[None,None,:,:].to(device)
    Zeta = Zeta[None,None,:,:].to(device)
    bf1 = 1/torch.sqrt(1+(Xi/(M*fc/2))**(2*order))
    bf2 = 1/torch.sqrt(1+(Zeta/(N*fc/2))**(2*order))
    return bf1*bf2
    
bf = butterworth(M,N)
viewimage(fftshift(bf))  # fftshift puts the zero frequency at the center of the display

# Usage example:
bf = butterworth(x0.shape[2],x0.shape[3],fc=.5)
x0f = ifft2(bf*fft2(x0)).real

viewimage(x0)
viewimage(x0f)
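A sketch of the super-resolution forward model, under the assumption that it consists of the Butterworth filter followed by subsampling by an integer factor `sf` (hypothetical name), with cut-off `fc = 1/sf` (also an assumption). `butterworth_cpu` reproduces the formula of `butterworth()` above without the device transfer. Since the filter is real and even, the filtering operator is self-adjoint, so $A^T$ is zero-filling upsampling followed by the same filter.

```python
import torch
from torch.fft import fft2, ifft2

def butterworth_cpu(M, N, fc=.5, order=5):
    # Same formula as butterworth() above, with integer FFT frequencies (0, 1, ..., -1) on the CPU
    xi = torch.fft.fftfreq(M, d=1.0 / M, dtype=torch.float64)[:, None]
    zeta = torch.fft.fftfreq(N, d=1.0 / N, dtype=torch.float64)[None, :]
    bf1 = 1 / torch.sqrt(1 + (xi / (M * fc / 2)) ** (2 * order))
    bf2 = 1 / torch.sqrt(1 + (zeta / (N * fc / 2)) ** (2 * order))
    return bf1 * bf2

def make_sr_operator(bf, sf):
    """A = S o B (anti-aliasing filter, then keep one pixel out of sf in each direction)
    and A^T = B o S^T (zero-filling upsampling, then the same filter)."""
    M, N = bf.shape[-2:]
    def blur(x):
        return ifft2(bf * fft2(x)).real
    def A(x):
        return blur(x)[..., ::sf, ::sf]
    def At(v):
        up = torch.zeros(v.shape[:-2] + (M, N), dtype=v.dtype)
        up[..., ::sf, ::sf] = v                 # S^T: put samples back, zeros elsewhere
        return blur(up)
    return A, At

# Usage: super-resolution by a factor 2, with the ASSUMED cut-off fc = 1/sf
torch.manual_seed(0)
sf_demo = 2
bf_demo = butterworth_cpu(32, 32, fc=1 / sf_demo)
A_demo, At_demo = make_sr_operator(bf_demo, sf_demo)
x_demo = torch.rand(1, 3, 32, 32, dtype=torch.float64)
v_demo = torch.rand(1, 3, 16, 16, dtype=torch.float64)
print(torch.sum(A_demo(x_demo) * v_demo).item(), torch.sum(x_demo * At_demo(v_demo)).item())
```

With these operators, $\nabla f(x) = \frac{1}{\nu^2} A^T(Ax - y)$ and, as $|\mathrm{bf}| \le 1$, $L \le 1/\nu^2$, so the PnP-PGD loop can be reused with only $A$ and $A^T$ swapped.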