In [None]:
!pip install lpips
!pip install scikit-image

In [None]:
!pip install git+https://github.com/deepinv/deepinv.git#egg=deepinv
import deepinv as dinv

In [None]:
import torch
import torchvision
import numpy as np
import time
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import lpips
from skimage.metrics import structural_similarity as ssim
from utils import pilimg_to_tensor, display_as_pilimg, inpainting_operator, psnr, blurring_operator, downsampling_operator, transposed_blurring_op
# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# The experiments below draw random noise with torch.randn_like; seed the
# random number generators once so that a Restart & Run All reproduces the
# same degraded images and therefore the same metrics.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# PnP-SGD solver, provided by the project-local `pnpsgd` module.
from pnpsgd import PnPSGD

# Scalar forwarded to PnPSGD in every run below.
# NOTE(review): presumably the noise level the denoiser is applied at
# (4 on a 0-255 intensity scale) -- confirm against pnpsgd.
nu = 4 / 255

# Pretrained GSDRUNet denoiser from deepinv (weights fetched via
# pretrained='download'), then moved onto the selected device.
D = dinv.models.GSDRUNet(pretrained='download')
D = D.to(device)

In [None]:
# Single-image sanity check: load one FFHQ image, blur it, add Gaussian noise
# and display both the original and the degraded observation.
idx = 11

x_true = torch.tensor(plt.imread('ffhq256/'+str(idx).zfill(5)+'.png'),device=device)
# Move the last (channel) axis to the front and add a leading batch axis.
x_true = x_true.permute(2,0,1).unsqueeze(0)
print("Original image :")
display_as_pilimg(x_true)

# Standard deviation of the additive Gaussian noise. The noise term below now
# reads this variable instead of repeating the literal, so changing the value
# here actually changes the degradation.
sigma_noise = 0.01
y = blurring_operator(x_true.clone(), device = device) + sigma_noise * torch.randn_like(x_true, device=device)
print("Blurring operator + noise :")
display_as_pilimg(y)

In [None]:
# Restore the degraded observation `y` with PnP-SGD and display the result.
# NOTE(review): the ground truth `x_true` is also handed to the solver; how it
# is used is not visible from this notebook -- confirm it is not used to
# drive the reconstruction itself.
x = PnPSGD(
    x_true,
    y,
    nu,
    blurring_operator,
    transposed_blurring_op,
    D,
    device,
)
display_as_pilimg(x)

Run the experiments for noise level σ = 0.05:

In [None]:
# Benchmark PnP-SGD deblurring on FFHQ images 00000-00024 with additive
# Gaussian noise of standard deviation `sigma_noise`. For every image the
# original, degraded and restored pictures are saved under results_pnp_05/
# and LPIPS, PSNR, SSIM and the solver run time are recorded.
# NOTE(review): the results_pnp_05/ directory is assumed to exist already --
# confirm whether display_as_pilimg creates it.
sigma_noise = 0.05
avg_lpips = []
avg_psnr = []
avg_ssim = []
avg_tps = []

loss_fn = lpips.LPIPS(net='alex').to(device)

for idx in range(25):
    x_true = torch.tensor(plt.imread('ffhq256/'+str(idx).zfill(5)+'.png'),device=device)
    # Move the last (channel) axis to the front and add a leading batch axis.
    x_true = x_true.permute(2,0,1).unsqueeze(0)
    print("Original image", str(idx).zfill(5)+'.png')
    display_as_pilimg(x_true,save = True, filename='results_pnp_05/true'+str(idx)+'.png')

    y = blurring_operator(x_true.clone(), device = device) + sigma_noise * torch.randn_like(x_true, device = device)
    print("Degraded image")
    display_as_pilimg(y,save = True, filename='results_pnp_05/degraded'+str(idx)+'.png')

    # CUDA kernels are launched asynchronously: synchronise before starting and
    # before stopping the clock so the measurement covers the actual compute.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    x = PnPSGD(x_true, y, nu, blurring_operator, transposed_blurring_op, D, device)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t1 = time.perf_counter()-t0

    # LPIPS expects inputs in [-1, 1] by default; normalize=True tells it the
    # images are in [0, 1] (the range assumed by data_range=1.0 below) and
    # rescales them. No gradients are needed for evaluation.
    with torch.no_grad():
        xlpips = loss_fn(x, x_true, normalize=True).item()
    xpsnr = psnr(x, x_true)
    # After squeeze(0) the arrays are channel-first, so channel_axis=0 makes
    # SSIM treat axis 0 as colour channels instead of a third spatial
    # dimension. `size_average` is not a scikit-image parameter and was dropped.
    xssim = ssim(x.squeeze(0).cpu().detach().numpy(), x_true.squeeze(0).cpu().detach().numpy(), win_size=3, data_range=1.0, channel_axis=0)
    print("LPIPS =", xlpips, ", PSNR =", xpsnr, ", SSIM = ", xssim, ", Execution time = ", t1)
    print("Restaured image")
    display_as_pilimg(x, save = True, filename = 'results_pnp_05/restaured'+str(idx)+'.png')

    avg_lpips.append(xlpips)
    avg_psnr.append(xpsnr)
    avg_ssim.append(xssim)
    avg_tps.append(t1)

In [None]:
# Mean of each per-image metric collected by the most recently run benchmark loop.
for label, values in (
    ("LPIPS", avg_lpips),
    ("PSNR", avg_psnr),
    ("SSIM", avg_ssim),
    ("Execution time", avg_tps),
):
    print("Average " + label + " : ", np.mean(values))

In [None]:
# Variance (np.var with its default arguments) of each per-image metric
# collected by the most recently run benchmark loop.
for label, values in (
    ("LPIPS", avg_lpips),
    ("PSNR", avg_psnr),
    ("SSIM", avg_ssim),
    ("Execution time", avg_tps),
):
    print("Variance " + label + " : ", np.var(values))

Run the experiments for noise level σ = 0.1:

In [None]:
# Benchmark PnP-SGD deblurring on FFHQ images 00000-00024 with additive
# Gaussian noise of standard deviation `sigma_noise`. For every image the
# original, degraded and restored pictures are saved under results_pnp_1/
# and LPIPS, PSNR, SSIM and the solver run time are recorded.
# NOTE(review): the results_pnp_1/ directory is assumed to exist already --
# confirm whether display_as_pilimg creates it.
sigma_noise = 0.1
avg_lpips = []
avg_psnr = []
avg_ssim = []
avg_tps = []

loss_fn = lpips.LPIPS(net='alex').to(device)

for idx in range(25):
    x_true = torch.tensor(plt.imread('ffhq256/'+str(idx).zfill(5)+'.png'),device=device)
    # Move the last (channel) axis to the front and add a leading batch axis.
    x_true = x_true.permute(2,0,1).unsqueeze(0)
    print("Original image", str(idx).zfill(5)+'.png')
    display_as_pilimg(x_true,save = True, filename='results_pnp_1/true'+str(idx)+'.png')

    y = blurring_operator(x_true.clone(), device = device) + sigma_noise * torch.randn_like(x_true, device = device)
    print("Degraded image")
    display_as_pilimg(y,save = True, filename='results_pnp_1/degraded'+str(idx)+'.png')

    # CUDA kernels are launched asynchronously: synchronise before starting and
    # before stopping the clock so the measurement covers the actual compute.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    x = PnPSGD(x_true, y, nu, blurring_operator, transposed_blurring_op, D, device)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t1 = time.perf_counter()-t0

    # LPIPS expects inputs in [-1, 1] by default; normalize=True tells it the
    # images are in [0, 1] (the range assumed by data_range=1.0 below) and
    # rescales them. No gradients are needed for evaluation.
    with torch.no_grad():
        xlpips = loss_fn(x, x_true, normalize=True).item()
    xpsnr = psnr(x, x_true)
    # After squeeze(0) the arrays are channel-first, so channel_axis=0 makes
    # SSIM treat axis 0 as colour channels instead of a third spatial
    # dimension. `size_average` is not a scikit-image parameter and was dropped.
    xssim = ssim(x.squeeze(0).cpu().detach().numpy(), x_true.squeeze(0).cpu().detach().numpy(), win_size=3, data_range=1.0, channel_axis=0)
    print("LPIPS =", xlpips, ", PSNR =", xpsnr, ", SSIM = ", xssim, ", Execution time = ", t1)
    print("Restaured image")
    display_as_pilimg(x, save = True, filename = 'results_pnp_1/restaured'+str(idx)+'.png')

    avg_lpips.append(xlpips)
    avg_psnr.append(xpsnr)
    avg_ssim.append(xssim)
    avg_tps.append(t1)

In [None]:
# Mean of each per-image metric collected by the most recently run benchmark loop.
for label, values in (
    ("LPIPS", avg_lpips),
    ("PSNR", avg_psnr),
    ("SSIM", avg_ssim),
    ("Execution time", avg_tps),
):
    print("Average " + label + " : ", np.mean(values))

In [None]:
# Variance (np.var with its default arguments) of each per-image metric
# collected by the most recently run benchmark loop.
for label, values in (
    ("LPIPS", avg_lpips),
    ("PSNR", avg_psnr),
    ("SSIM", avg_ssim),
    ("Execution time", avg_tps),
):
    print("Variance " + label + " : ", np.var(values))

Run the experiments without additive noise:

In [None]:

# Benchmark PnP-SGD deblurring on FFHQ images 00000-00024 without additive
# noise (the observation is the blurred image only). For every image the
# original, degraded and restored pictures are saved under results_pnp_0/
# and LPIPS, PSNR, SSIM and the solver run time are recorded.
# NOTE(review): the results_pnp_0/ directory is assumed to exist already --
# confirm whether display_as_pilimg creates it.
avg_lpips = []
avg_psnr = []
avg_ssim = []
avg_tps = []

loss_fn = lpips.LPIPS(net='alex').to(device)

for idx in range(25):
    x_true = torch.tensor(plt.imread('ffhq256/'+str(idx).zfill(5)+'.png'),device=device)
    # Move the last (channel) axis to the front and add a leading batch axis.
    x_true = x_true.permute(2,0,1).unsqueeze(0)
    print("Original image", str(idx).zfill(5)+'.png')
    display_as_pilimg(x_true,save = True, filename='results_pnp_0/true'+str(idx)+'.png')

    y = blurring_operator(x_true.clone(), device = device)
    print("Degraded image")
    display_as_pilimg(y,save = True, filename='results_pnp_0/degraded'+str(idx)+'.png')

    # CUDA kernels are launched asynchronously: synchronise before starting and
    # before stopping the clock so the measurement covers the actual compute.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    x = PnPSGD(x_true, y, nu, blurring_operator, transposed_blurring_op, D, device)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t1 = time.perf_counter()-t0

    # LPIPS expects inputs in [-1, 1] by default; normalize=True tells it the
    # images are in [0, 1] (the range assumed by data_range=1.0 below) and
    # rescales them. No gradients are needed for evaluation.
    with torch.no_grad():
        xlpips = loss_fn(x, x_true, normalize=True).item()
    xpsnr = psnr(x, x_true)
    # After squeeze(0) the arrays are channel-first, so channel_axis=0 makes
    # SSIM treat axis 0 as colour channels instead of a third spatial
    # dimension. `size_average` is not a scikit-image parameter and was dropped.
    xssim = ssim(x.squeeze(0).cpu().detach().numpy(), x_true.squeeze(0).cpu().detach().numpy(), win_size=3, data_range=1.0, channel_axis=0)
    print("LPIPS =", xlpips, ", PSNR =", xpsnr, ", SSIM = ", xssim, ", Execution time = ", t1)
    print("Restaured image")
    display_as_pilimg(x, save = True, filename = 'results_pnp_0/restaured'+str(idx)+'.png')

    avg_lpips.append(xlpips)
    avg_psnr.append(xpsnr)
    avg_ssim.append(xssim)
    avg_tps.append(t1)

In [None]:
# Mean of each per-image metric collected by the most recently run benchmark loop.
for label, values in (
    ("LPIPS", avg_lpips),
    ("PSNR", avg_psnr),
    ("SSIM", avg_ssim),
    ("Execution time", avg_tps),
):
    print("Average " + label + " : ", np.mean(values))

In [None]:
# Variance (np.var with its default arguments) of each per-image metric
# collected by the most recently run benchmark loop.
for label, values in (
    ("LPIPS", avg_lpips),
    ("PSNR", avg_psnr),
    ("SSIM", avg_ssim),
    ("Execution time", avg_tps),
):
    print("Variance " + label + " : ", np.var(values))