# Plug-and-Play Image Restoration

<br/><br/><br/>

In this practical session, we will implement **plug-and-play algorithms with the gradient-step denoiser**.

In [None]:
import numpy as np
import torch
from torch.fft import fft2, ifft2, fftshift, ifftshift
import matplotlib.pyplot as plt
import time    

import torch

# Report the installed PyTorch version.
print(torch.__version__)
pi = torch.pi

# Work on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

!pip install pytorch-lightning
# !pip install bm3d
# !pip install deepinv
# or last version of deepinv:
!pip install git+https://github.com/deepinv/deepinv.git#egg=deepinv
import deepinv as dinv

# Download and unpack the files for this session (comment these two lines out once the files are present)
!wget https://perso.telecom-paristech.fr/aleclaire/mva/tp9.zip
!unzip tp9.zip

In [None]:
def rgb2gray(u):
    """Convert an image whose third axis holds R, G, B into a gray-level image.

    The result is the weighted sum 0.2989*R + 0.5870*G + 0.1140*B of the
    first three channels along axis 2.
    """
    red, green, blue = u[:, :, 0], u[:, :, 1], u[:, :, 2]
    return 0.2989 * red + 0.5870 * green + 0.1140 * blue

def str2(chars):
    """Format a number as a string with two digits after the decimal point."""
    return format(chars, ".2f")

def psnr(uref,ut,M=1):
    """Peak signal-to-noise ratio (in dB) between two torch tensors.

    Parameters
    ----------
    uref : torch.Tensor
        Reference image.
    ut : torch.Tensor
        Image compared to the reference (same shape as `uref`).
    M : float
        Peak value of the image dynamics (1 for images in [0,1]).

    Returns
    -------
    20*log10(M/RMSE) as a numpy scalar. Identical inputs give RMSE = 0 and
    therefore `inf`.
    """
    # detach() is required: the iterates of the optimization loops carry
    # requires_grad=True, and such tensors cannot be converted to numpy.
    ref = uref.detach().cpu().numpy()
    est = ut.detach().cpu().numpy()
    rmse = np.sqrt(np.mean((ref-est)**2))
    return 20*np.log10(M/rmse)

def tensor2im(x):
    """Turn a 4-D tensor into a CPU image tensor that matplotlib can show.

    The tensor is detached, moved to the CPU, its axes are permuted with
    (2,3,1,0), singleton axes are squeezed out, and values are clipped
    to [0,1].
    """
    img = x.detach().cpu()
    img = img.permute(2,3,1,0).squeeze()
    return img.clip(0,1)

# viewimage
import tempfile
import IPython
from skimage.transform import rescale

def viewimage(im, normalize=True,vmin=0,vmax=1,z=2,order=0,titre='',displayfilename=False):
    """Display a 4-D torch image tensor inline in the notebook, as a PNG.

    The tensor is detached, moved to the CPU, permuted with (2,3,1,0) and
    squeezed, then zoomed with `skimage.transform.rescale` and written to a
    temporary PNG file which is shown with IPython.

    Parameters
    ----------
    im : torch.Tensor
        Image tensor to display.
    normalize : bool
        If True, map [vmin, vmax] linearly to [black, white]. If False, the
        values are clipped to [0,255] and divided by 255.
    vmin, vmax : float or None
        Displayed dynamics (black=0 and white=1 by default). Pass None to
        use the minimum / maximum of the image.
    z : float
        Zoom factor passed to `rescale`.
    order : int
        Interpolation order passed to `rescale`.
    titre : str
        String inserted in the suffix of the temporary file name.
    displayfilename : bool
        If True, print the path of the temporary PNG file.
    """
    im = im.detach().cpu().permute(2,3,1,0).squeeze()
    imin = np.array(im).astype(np.float32)
    channel_axis = 2 if len(im.shape)>2 else None
    imin = rescale(imin, z, order=order, channel_axis=channel_axis)
    if normalize:
        if vmin is None:
            vmin = imin.min()
        if vmax is None:
            vmax = imin.max()
        if np.abs(vmax-vmin)>1e-10:
            # vmin is subtracted exactly once (after clipping); subtracting
            # it beforehand as well shifted the image twice whenever
            # vmin != 0.
            imin = (imin.clip(vmin,vmax)-vmin)/(vmax-vmin)
        else:
            # Degenerate dynamics (vmin == vmax): show a uniform black image
            # while keeping an array of the right shape.
            imin = np.zeros_like(imin)
    else:
        imin = imin.clip(0,255)/255
    imin = (imin*255).astype(np.uint8)
    # NamedTemporaryFile reserves the file name atomically, unlike the
    # deprecated tempfile.mktemp. The file is kept (delete=False) so that it
    # can be written by matplotlib and read back by IPython.
    with tempfile.NamedTemporaryFile(suffix=titre+'.png', delete=False) as tmp:
        filename = tmp.name
    if displayfilename:
        print(filename)
    plt.imsave(filename, imin, cmap='gray')
    IPython.display.display(IPython.display.Image(filename))


# alternative viewimage if the other one does not work:
def Viewimage(im,dpi=100,cmap='gray'):
    """Fallback display: show `im` in a matplotlib figure without axes.

    `im` is converted with `np.array`; `dpi` sets the figure resolution and
    `cmap` the colormap (None selects matplotlib's default colormap).
    """
    plt.figure(dpi=dpi)
    # imshow's own default for cmap is None, so a single call covers both cases.
    plt.imshow(np.array(im), cmap=cmap)
    plt.axis('off')
    plt.show()

<br/><br/><br/>

# Exercise 1: Plug-and-Play Image Deblurring (with periodic convolution)

First, we load a clean image and a blur kernel.

In [None]:
# Open the image (read as an array with three axes, unpacked below as M, N, C)
x0 = torch.tensor(plt.imread('im/simpson512crop.png'),device=device)
# x0 = torch.tensor(plt.imread('im/parrots.png'),device=device); x0 = x0[100:356,370:626,:]
# x0 = torch.tensor(plt.imread('im/marge2.png'),device=device)
# x0 = torch.tensor(plt.imread('im/simpson512.png'),device=device)
M,N,C = x0.shape
# Permute dimensions to fit tensor convention: (M,N,C) -> (1,C,M,N)
x0 = x0.permute(2,0,1).unsqueeze(0).contiguous()

viewimage(x0)

# Load a blur kernel (2-D array of size m x n read from a text file)
kt = torch.tensor(np.loadtxt('kernels/kernel8.txt'))
# kt = np.loadtxt('kernels/levin7.txt')
(m,n) = kt.shape

plt.imshow(kt)
plt.title('Blur kernel')
plt.show()

# Embed the kernel (normalized to unit sum) in a MxN image, and put its center at pixel (0,0)
k = torch.zeros((M,N),device=device)
k[0:m,0:n] = kt/torch.sum(kt)
# circular shift so that kernel pixel (int(m/2), int(n/2)) lands at index (0,0)
k = torch.roll(k,(-int(m/2),-int(n/2)),(0,1))
# add two leading singleton axes so that k broadcasts against (1,C,M,N) images
k = k[None,None,:,:]
# 2-D discrete Fourier transform of the kernel
fk = fft2(k)

## Forward Model

Implement the forward model
$$ y = A(x_0) + \xi $$
where $\xi \sim \mathcal{N}(0,\nu^2 \mathsf{Id})$.
Write functions implementing the operator $A(x)$, the data-fidelity term $f(x)$, and its proximal operator $\mathsf{Prox}_{\tau f}$.

In [None]:
nu = .01  # noise level (standard deviation of the Gaussian noise xi)
torch.manual_seed(1)  # fix random seed for reproducibility

# Define corresponding operator and data-fidelity
# NOTE: later cells call the operator as `A(x)` and the data-fidelity as `f(x)`,
# so the definitions written here should use these names.

### TODO ####

# Draw a sample of the direct model for image deblurring (apply blur and add Gaussian noise)
# NOTE: the degraded image must be stored in `y`, which is displayed below
# and used as initialization by the following cells.
### TODO ####

viewimage(y)

## Test the Gradient-Step Denoiser and the corresponding regularization

In the following, we will use several algorithms to minimize
$$F(x) = f(x) + \lambda g_{\sigma}(x)$$
where $g_{\sigma}$ is the regularization function linked to the gradient-step denoiser.

The function $g_{\sigma}$ is accessible through `D.potential`, and the corresponding gradient-step denoiser then writes
$$ D_{\sigma} = \mathsf{Id} - \nabla g_{\sigma} .$$
<br/><br/>

In the following cell, test the gradient-step denoiser on a noisy image `xnoisy`.

In [None]:
# load the gradient-step denoiser
D=dinv.models.GSDRUNet(pretrained='ckpts/GSDRUNet.ckpt', train=False).to(device)
# D=dinv.models.GSDRUNet(pretrained='ckpts/Prox-DRUNet.ckpt', train=False).to(device)


# noisy version of the clean image, with noise standard deviation nu
xnoisy = x0+nu*torch.randn_like(x0,device=device)
# copy that tracks gradients, so that the potential can be differentiated w.r.t. x
x = xnoisy.clone().requires_grad_(True)
# value of the regularization g_sigma at x
px = D.potential(x,sigma=nu)
# gradient of g_sigma at x, obtained by automatic differentiation
grad = torch.autograd.grad(px,x)[0]

# gradient-step denoiser: D_sigma = Id - grad g_sigma
Dx = xnoisy-grad

viewimage(Dx)

# # alternately, you can use
# Dx = D(xnoisy,sigma=nu)
# # but this causes a requires_grad problem on a previous version of deepinv.

## PnP-GD with Gradient-Step Denoiser

Implement the gradient descent algorithm 
$$ x_{k+1} = x_k - \tau \nabla F(x_k) .$$
Along the iterations, track the function values $F(x_k)$, the PSNR and the residual norm $r_k = \|x_{k}-x_{k-1}\|$.

In [None]:
# GSDRUNET
# https://deepinv.github.io/deepinv/stubs/deepinv.models.GSDRUNet.html
D=dinv.models.GSDRUNet(pretrained='ckpts/GSDRUNet.ckpt', train=False).to(device)

tau = 2e-4  # gradient descent step size
s = 1.8*nu  # strength of the denoiser (sigma in the text)
lam = 1600  # regularization weight (lambda in the text)

# Define the objective function F(x) = f(x) + lam * g_sigma(x)
def F(x, lam):
    return ### TODO ###

# initialize with the degraded image; gradients are tracked w.r.t. x
x = torch.clone(y).requires_grad_(True)
# norm of the initial point (presumably meant to normalize the residual r, as in the PGD cell further down)
normxinit = torch.linalg.vector_norm(x.detach())

losstab = []  # values F(x_k)
psnrtab = []  # PSNR of the iterates
rtab = []     # residual norms
niter = 100
t0 = time.time()
print('[%4d/%4d] [%.5f s] PSNR = %.2f'%(0,niter,0,psnr(x0,y)))



for it in range(niter):

    # The code written here must update x and define the variables used
    # below: `loss` (tensor), `psnrt` (number) and `r` (tensor).
    ### TODO ###

    losstab.append(loss.item())
    psnrtab.append(psnrt)
    rtab.append(r.item())

    if (it+1)%10==0:
        print('[%4d/%4d] [%.5f s] PSNR = %.2f, F(x) = %.6f, tau = %.6f'%(it+1,niter,time.time()-t0,psnrt,loss.item(),tau))
        viewimage(x)


plt.plot(losstab)
plt.title('F(x_k)')
plt.show()

plt.plot(psnrtab)
plt.title('PSNR')
plt.show()

plt.semilogy(rtab)
plt.title('Residual Norm')
plt.show()

# save results (used by the comparison cell)
x_gsgd = x.detach().clone()
losstab_gsgd = losstab.copy()
psnrtab_gsgd = psnrtab.copy()
rtab_gsgd = rtab.copy()

## PnP-PGD with Gradient-Step Denoiser And Backtracking

Implement the proximal gradient descent algorithm 
$$ x_{k+1} = \mathsf{Prox}_{\tau f}\big(x_k - \tau  \lambda \nabla g_{\sigma}(x_k) \big) .$$
Along the iterations, track the function values $F(x_k)$, the PSNR and the residual norm $r_k = \|x_{k}-x_{k-1}\|$.

Modify your code in order to include the backtracking procedure: at each iteration,
$$ \textbf{while} \quad  F(x_k) - F(T_\tau(x_k)) < \frac{\gamma}{\tau} \|T_\tau(x_k) - x_k\|^2 \quad \textbf{do} \quad \tau \leftarrow \eta \tau .$$
You may take $\gamma = 0.4, \eta = 0.9$.

Modify your code so that the output image $\tilde{x}_k$ (variable `xvisu`) is the one obtained before the last proximal step:
$$ x_k = \mathsf{Prox}_{\tau f}(\tilde{x}_k) .$$
You may display `xvisu` instead of `x` and track PSNR with `xvisu`. (In order to improve visual quality, it is useful not to apply the last data-fidelity step which adds back some noise.) <br/>
You may also display the value $F(\tilde{x}_k)$ which corresponds to values obtained by the HQS algorithm
$$ \tilde{x}_{k+1} = D_\sigma(\mathsf{Prox}_{\tau f}(\tilde{x}_k))$$
where $D_\sigma = \mathsf{Id} - \nabla g_\sigma$ is (by abuse) seen here as a proximal regularization step.

In [None]:
# GSDRUNET
# https://deepinv.github.io/deepinv/stubs/deepinv.models.GSDRUNet.html
D=dinv.models.GSDRUNet(pretrained='ckpts/GSDRUNet.ckpt', train=False).to(device)

# Parameters
tau = 2e-4  # initial step size (decreased by the backtracking procedure)
s = 1.8*nu  # strength of the denoiser (sigma in the text)
lam = 1600  # regularization weight (lambda in the text)

# initialize with the degraded image; gradients are tracked w.r.t. x
x = torch.clone(y).requires_grad_(True)
normxinit = torch.linalg.vector_norm(x.detach())
# previous value of the visualized iterate \tilde{x}
xvisuold = x.detach().clone()

# tables for the iterates x_k
losstab = []
psnrtab = []
rtab = []
# same tables for \tilde{x} :
losstab2 = []
psnrtab2 = []
rtab2 = []

niter = 100
t0 = time.time()
print('[%4d/%4d] [%.5f s] PSNR = %.2f'%(0,niter,0,psnr(x0,y)))

for it in range(niter):


    # The code written here must update x, define `xvisu` (displayed below
    # and saved after the loop) and fill the six tables defined above.
    ### TODO ###

    if (it+1)%10==0:
        # print('[%4d/%4d] [%.5f s] PSNR = %.2f, F(x) = %.6f, tau = %.6f'%(it+1,niter,time.time()-t0,psnrt2,Fxnew.item(),tau))
        viewimage(xvisu)


plt.plot(losstab,label='PGD')
plt.plot(losstab2,label='PGD2')
plt.title('F(x_n)')
plt.legend()
plt.show()

plt.plot(psnrtab,label='PGD')
plt.plot(psnrtab2,label='PGD2')
plt.title('PSNR')
plt.legend()
plt.show()

plt.semilogy(rtab,label='PGD')
plt.semilogy(rtab2,label='PGD2')
plt.title('Residual Norm')
plt.legend()
plt.show()

# save results for the iterates x_k (used by the comparison cell)
x_gspgd = x.detach().clone()
losstab_gspgd = losstab.copy()
psnrtab_gspgd = psnrtab.copy()
rtab_gspgd = rtab.copy()

# save results for the visualized iterates \tilde{x}_k
x_gspgd2 = xvisu.clone()
losstab_gspgd2 = losstab2.copy()
psnrtab_gspgd2 = psnrtab2.copy()
rtab_gspgd2 = rtab2.copy()

# # save results
# tmp_x_gspgd = x.detach().clone()
# tmp_losstab_gspgd = losstab.copy()
# tmp_psnrtab_gspgd = psnrtab.copy()
# tmp_rtab_gspgd = rtab.copy()

# tmp_x_gspgd2 = xvisu.clone()
# tmp_losstab_gspgd2 = losstab2.copy()
# tmp_psnrtab_gspgd2 = psnrtab2.copy()
# tmp_rtab_gspgd2 = rtab2.copy()

## PGD with standard DRUNET denoiser

Implement the PGD algorithm
$$ x_{k+1} = D_\sigma(x_k - \tau \nabla f(x_k)) $$
by taking directly $D_\sigma$ as the DRUNET (or GS-DRUNET) denoiser.

In [None]:
D = dinv.models.DRUNet(pretrained='ckpts/drunet_color.pth').to(device)
# D = dinv.models.DnCNN(pretrained='ckpts/dncnn_sigma2_color.pth').to(device)
# D = dinv.models.DnCNN(pretrained='ckpts/dncnn_sigma2_lipschitz_color.pth').to(device)

tau = 1.9*nu**2  # step size of the gradient step on f
s = 2*nu  # strength of the denoiser

# tau = .1*nu**2
# s= .1*nu   # make things diverge for DRUNet

# initialize with the degraded image
x = y.clone()
normxinit = torch.linalg.vector_norm(x)

psnrtab = []  # PSNR of the iterates
rtab = []     # relative residual norms ||x_{k+1}-x_k|| / ||x_0||
# Denoiser variations ||D(u_k)-D(u_{k-1})|| / ||u_k-u_{k-1}||, where u_k is
# the input given to the denoiser at iteration k. This is an empirical
# indicator of how much the denoiser expands distances along the iterations.
# NOTE(review): reconstructed from the previously unused xold/Dxold and from
# the 'Denoiser Variations' plot below -- confirm this is the intended quantity.
vtab = []
xold = None   # previous denoiser input
Dxold = None  # previous denoiser output
niter = 1000
t0 = time.time()
print('[%4d/%4d] [%.5f s] PSNR = %.2f'%(0,niter,0,psnr(x0,y)))
for it in range(niter):
    # gradient of the data-fidelity at the current iterate (autograd)
    xt = x.clone().requires_grad_(True)
    fx = f(xt)
    grad = torch.autograd.grad(outputs=fx, inputs=xt)[0]
    with torch.no_grad():
        # gradient step on f, followed by the denoiser
        xnew = x-tau*grad
        Dxnew = D(xnew,sigma=s)
        if xold is not None:
            # no previous input/output pair exists at the first iteration
            v = torch.linalg.vector_norm(Dxnew-Dxold)/torch.linalg.vector_norm(xnew-xold)
            vtab.append(v.item())
        x = Dxnew
        xold = xnew.clone()
        Dxold = Dxnew.clone()
    psnrt = psnr(x0,x)
    # xt still holds the previous iterate x_k
    r = torch.linalg.vector_norm(xt.detach()-x)/normxinit
    psnrtab.append(psnrt)
    rtab.append(r.item())

    if (it+1)%10==0:
        print('[%4d/%4d] [%.5f s] PSNR = %.2f'%(it+1,niter,time.time()-t0,psnrt))
        viewimage(x)

x_pgd = x.detach().clone()
# NOTE(review): no objective value is tracked in this cell (a plain DRUNet
# has no explicit potential), so `losstab` is whatever the previously run
# cell left behind. The 'PGD' curve of the F(x_n) comparison plot is
# therefore not meaningful -- confirm whether it should be dropped.
losstab_pgd = losstab.copy()
psnrtab_pgd = psnrtab.copy()
rtab_pgd = rtab.copy()

plt.plot(psnrtab)
plt.title('PSNR')
plt.show()

plt.semilogy(rtab)
plt.title('Residual Norm')
plt.show()

plt.plot(vtab)
plt.title('Denoiser Variations')
plt.show()

## PnP-PGD with Proximal Denoiser

Now, we will consider the proximal gradient descent algorithm the other way round:
$$ \begin{cases} z_{k+1} = x_k - \frac{1}{\lambda} \nabla f(x_k) \\ x_{k+1} = D_{\sigma} (z_{k+1}) \end{cases} ,$$
where $D_{\sigma}$ is seen as a proximal operator.

In order to have proper convergence, we will consider the proximal denoiser
$$ D_\sigma = \mathsf{Id} - \nabla g_{\sigma} = \mathsf{Prox}_{\phi_\sigma} , $$
which can be seen as the prox of a certain function $\phi_\sigma$ (notice that it imposes $\tau = 1$).

Because of the relation between $x_k,z_k$, the objective function on the iterates is
$$ \frac{1}{\lambda} F(x_k) = \frac{1}{\lambda} f(x_k) + g_\sigma(z_k) - \frac{1}{2} \|x_k - z_k\|^2 .$$

Implement this PGD algorithm relying on the proximal denoiser. Along the iterations, track the function values $F(x_k)$, the PSNR and the residual norm $r_k = \|x_{k}-x_{k-1}\|$.

In [None]:
# test the proximal denoiser
D=dinv.models.GSDRUNet(pretrained='ckpts/Prox-DRUNet.ckpt', act_mode='S',train=False).to(device)
# # WARNING: be careful when loading the prox-denoiser, which has SoftPlus activation functions.

# test it on an image: here, the degraded image y
x = y.clone().requires_grad_(True)
# value of the potential g_sigma at x, and its gradient by automatic differentiation
px = D.potential(x,sigma=nu)
grad = torch.autograd.grad(px,x)[0]
# proximal denoiser: D_sigma = Id - grad g_sigma
Dx = x.detach()-grad

viewimage(Dx)

In [None]:
D=dinv.models.GSDRUNet(pretrained='ckpts/Prox-DRUNet.ckpt', act_mode='S', train=False).to(device)

tau = 2e-4
s = 1.8*nu  # strength of the denoiser (sigma in the text)
lam = 1e5   # regularization weight (lambda in the text)

# initialize with the degraded image
xinit = torch.clone(y)
xold = xinit.clone().requires_grad_(True)
normxinit = torch.linalg.vector_norm(xinit)

losstab = []  # values F(x_k)
psnrtab = []  # PSNR of the iterates
rtab = []     # residual norms

niter = 100
t0 = time.time()
print('[%4d/%4d] [%.5f s] PSNR = %.2f'%(0,niter,0,psnr(x0,y)))

for it in range(niter):


    # The code written here must compute the new iterate `xnew` from `xold`,
    # define `psnrt` and `Fxnew` (used below) and fill the three tables.
    ### TODO ###

    if (it+1)%10==0:
        print('[%4d/%4d] [%.5f s] PSNR = %.2f, F(x) = %.6f'%(it+1,niter,time.time()-t0,psnrt,Fxnew.item()))
        viewimage(xnew)


# Curves are labelled so that plt.legend() has something to display.
plt.plot(losstab,label='ProxPGD')
plt.title('F(x_n)')
plt.legend()
plt.show()

plt.plot(psnrtab,label='ProxPGD')
plt.title('PSNR')
plt.legend()
plt.show()

plt.semilogy(rtab,label='ProxPGD')
plt.title('Residual Norm')
plt.legend()
plt.show()

# save results
# The last iterate of this cell is `xnew` (the variable displayed in the
# loop); `x` is not updated here and would hold a leftover from another cell.
x_proxpgd = xnew.detach().clone()
losstab_proxpgd = losstab.copy()
psnrtab_proxpgd = psnrtab.copy()
rtab_proxpgd = rtab.copy()

# # save results
# tmp_x_proxpgd = x.detach().clone()
# tmp_losstab_proxpgd = losstab.copy()
# tmp_psnrtab_proxpgd = psnrtab.copy()
# tmp_rtab_proxpgd = rtab.copy()

## Comparisons

In [None]:
# Compare the results obtained with the various algorithms implemented above.

d = 20  # number of initial iterations left out of the plots


def _tail(tab, start=d):
    """Return (iteration indices, values) of `tab` from iteration `start` on.

    The indices are built from the length of each table, so that tables
    recorded with different numbers of iterations can share a figure.
    """
    return np.arange(start, len(tab)), tab[start:]


plt.figure(dpi=150)
plt.semilogy(*_tail(losstab_gsgd),label='GSGD')
plt.semilogy(*_tail(losstab_gspgd),label='GSPGD')
plt.semilogy(*_tail(losstab_proxpgd),label='ProxPGD')
plt.semilogy(*_tail(losstab_pgd),label='PGD')
# plt.semilogy(*_tail(losstab_gspgd2),label='GSPGD2')
plt.title('F(x_n)')
plt.legend()
plt.show()

plt.figure(dpi=150)
plt.plot(*_tail(psnrtab_gsgd),label='GSGD')
plt.plot(*_tail(psnrtab_gspgd),label='GSPGD')
plt.plot(*_tail(psnrtab_proxpgd),label='ProxPGD')
plt.plot(*_tail(psnrtab_pgd),label='PGD')
# plt.plot(*_tail(psnrtab_gspgd2),label='GSPGD2')
plt.title('PSNR')
plt.legend()
plt.show()

plt.figure(dpi=150)
plt.semilogy(*_tail(rtab_gsgd),label='GSGD')
plt.semilogy(*_tail(rtab_gspgd),label='GSPGD')
plt.semilogy(*_tail(rtab_proxpgd),label='ProxPGD')
plt.semilogy(*_tail(rtab_pgd),label='PGD')
# plt.semilogy(*_tail(rtab_gspgd2),label='GSPGD2')
plt.title('Residual Norm')
plt.legend()
plt.show()

## Baseline Comparisons with explicit regularizations (Tychonov or SmoothTV)

In [None]:
# Deblurring with Tychonov Regularization
def tych_deblur(y,k,lam=0.01):
    """Deblur `y` with a Tychonov-regularized inverse filter (periodic case).

    The restored image is computed in the Fourier domain as

        fft2(x) = conj(fk) * fft2(y) / (|fk|^2 + 8*lam*(sin(pi*xi/M)^2 + sin(pi*zeta/N)^2))

    where fk = fft2(k). Since 4*sin(pi*xi/M)^2 is the Fourier symbol of the
    squared periodic finite difference along one axis, the denominator
    penalizes the image gradient with weight 2*lam.

    Parameters
    ----------
    y : torch.Tensor
        Degraded image, of shape (B,C,M,N).
    k : torch.Tensor
        Blur kernel embedded in an image whose last two axes are (M,N), with
        its center at pixel (0,0); it must broadcast against `y`.
    lam : float
        Regularization weight.

    Returns
    -------
    torch.Tensor
        Real part of the restored image, same shape as `y`.
    """
    _,_,M,N = y.shape
    # The kernel given as argument is used (not a global Fourier transform),
    # and frequency grids are created on the device of the input.
    fk = fft2(k)
    # fftfreq returns the signed normalized frequencies xi/M
    sx = torch.sin(torch.pi*torch.fft.fftfreq(M,device=y.device))**2
    sz = torch.sin(torch.pi*torch.fft.fftfreq(N,device=y.device))**2
    # (M,1) + (1,N) broadcasts to the full (M,N) frequency grid
    symbol = (sx[:,None] + sz[None,:])[None,None,:,:]
    fh = torch.conj(fk)/(torch.abs(fk)**2 + 8 * lam * symbol)
    return ifft2(fft2(y)*fh).real

# Tychonov restoration of the degraded image y, with the default weight lam
xtych = tych_deblur(y,k)

# Smooth TV regularization
def stv_deblur(A,y,xinit,niter=1000,lam=0.002,ep=0.01,lr=None):
    """Restore an image by gradient descent on a smoothed-TV objective.

    The minimized function is

        sum((A(x)-y)^2) + lam * sum(sqrt(ep^2 + d1^2 + d2^2))

    where d1 and d2 are periodic forward differences of x along axes 2 and 3.

    Parameters
    ----------
    A : callable
        Forward operator, applied to the current image tensor.
    y : torch.Tensor
        Observed image.
    xinit : torch.Tensor
        Initial point (it is cloned, not modified).
    niter : int
        Number of gradient descent iterations.
    lam : float
        Regularization weight.
    ep : float
        Smoothing parameter of the total variation.
    lr : float or None
        Step size; when None, 1.9/(1+lam*8/ep) is used.

    Returns
    -------
    (x, losslist) : the detached final iterate and the list of objective
    values recorded before each update.
    """
    step = 1.9/(1+lam*8/ep) if lr is None else lr
    x = xinit.clone().requires_grad_(True)
    optim = torch.optim.SGD([x], lr=step)
    history = []
    for _ in range(niter):
        optim.zero_grad()
        # periodic forward differences along the two spatial axes
        dv = torch.roll(x, shifts=-1, dims=2) - x
        dh = torch.roll(x, shifts=-1, dims=3) - x
        smooth_tv = torch.sqrt(ep**2 + dv**2 + dh**2).sum()
        fidelity = ((A(x) - y)**2).sum()
        loss = fidelity + lam*smooth_tv
        history.append(loss.item())
        loss.backward()
        optim.step()
    return x.detach(), history

# Smoothed-TV restoration, initialized with the degraded image
xtv,_ = stv_deblur(A,y,y.clone())

# # Display the results
# print('PSNR(x0,xtych) = %.2f'%psnr(x0,xtych))
# viewimage(xtych)
# print('PSNR(x0,xtv) = %.2f'%psnr(x0,xtv))
# viewimage(xtv)

# One (image, title) pair per panel of the 2x2 figure, in reading order.
# The first panel may be replaced by the degraded image:
#   (y, 'Degraded \n PSNR='+str2(psnr(x0,y)))
panels = [
    (x0, 'Original'),
    (xtych, 'Tychonov \n PSNR='+str2(psnr(x0,xtych))),
    (xtv, 'SmoothTV \n PSNR='+str2(psnr(x0,xtv))),
    (x_gspgd, 'PnP-PGD \n PSNR='+str2(psnr(x0,x_gspgd))),
]

plt.figure(dpi=180)
for pos, (img, title) in enumerate(panels, start=1):
    plt.subplot(2,2,pos)
    plt.imshow(tensor2im(img), cmap='gray')
    plt.title(title,fontsize=8)
    plt.axis('off')
plt.show()

<br/><br/><br/>

# Exercise 2: Image Super-resolution

<br/><br/>
Use your favorite plug-and-play algorithm to tackle image super-resolution in a stable way.

The direct model $Ax = (k*x)_{\downarrow p}$ is a composition of an anti-aliasing filter (e.g. Butterworth filter, see below) with a downsampling step with stride $p$.

In [None]:
# Adjust the framework to address super-resolution with smoothed TV
# For anti-aliasing, you may use the Butterworth filter of order n and cut-off frequency fc 
#   given below

# fc is the cut-off frequency normalized in (0,1)
def butterworth(M,N,fc=.5,order=5):
    """Frequency response of a separable Butterworth low-pass filter.

    Along each axis the gain is 1/sqrt(1 + (xi/(size*fc/2))^(2*order)),
    where xi is the signed integer frequency index; the returned filter is
    the product of the gains of the two axes.

    Parameters
    ----------
    M, N : int
        Size of the frequency grid.
    fc : float
        Cut-off frequency, normalized in (0,1).
    order : int
        Order of the filter.

    Returns
    -------
    torch.Tensor
        Tensor of shape (1,1,M,N), on the global `device`.
    """
    def signed_index(size):
        # indices 0..size-1, where those above size/2 are wrapped to negative values
        idx = torch.arange(size)
        return torch.where(idx > size/2, idx - size, idx)

    rows, cols = torch.meshgrid(signed_index(M), signed_index(N), indexing='ij')
    rows = rows[None,None,:,:].to(device)
    cols = cols[None,None,:,:].to(device)
    gain_rows = 1/torch.sqrt(1+(rows/(M*fc/2))**(2*order))
    gain_cols = 1/torch.sqrt(1+(cols/(N*fc/2))**(2*order))
    return gain_rows*gain_cols
    
# Display the frequency response of the filter (default cut-off and order)
bf = butterworth(M,N)
viewimage(bf)

# Use example: low-pass filter the clean image by multiplication in the Fourier domain
bf = butterworth(x0.shape[2],x0.shape[3],fc=.5)
x0f = ifft2(bf*fft2(x0)).real

# original image, then its low-pass filtered version
viewimage(x0)
viewimage(x0f)