In [None]:
!pip install lpips
!pip install scikit-image
!pip install matplotlib

In [None]:
import torch
import torchvision
import numpy as np
import time

import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm
from guided_diffusion.unet import create_model

from utils import pilimg_to_tensor, display_as_pilimg, inpainting_operator, psnr, blurring_operator, downsampling_operator, transposed_blurring_op
from dps import DPS
import lpips
from skimage.metrics import structural_similarity as ssim
# Pick the compute device: the first CUDA GPU when torch reports one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print("Device:", device)

In [None]:
# Keyword arguments forwarded verbatim to guided_diffusion's `create_model`.
# NOTE(review): `model_path` presumably names the checkpoint that
# `create_model` loads - confirm against guided_diffusion.unet.
model_config = dict(
    image_size=256,
    num_channels=128,
    num_res_blocks=1,
    channel_mult='',
    learn_sigma=True,
    class_cond=False,
    use_checkpoint=False,
    attention_resolutions=16,
    num_heads=4,
    num_head_channels=64,
    num_heads_upsample=-1,
    use_scale_shift_norm=True,
    dropout=0.0,
    resblock_updown=True,
    use_fp16=False,
    use_new_attention_order=False,
    model_path='ffhq_10m.pt',
)

# Build the network, move it to the selected device and call eval() on it.
model = create_model(**model_config).to(device)
model.eval()

# Sampler object used by every experiment below; it is constructed from the
# model, the number of timesteps and the device.
num_timesteps = 1000
dps = DPS(model, num_timesteps, device)

Example: applying the inpainting operator to a single image:

In [None]:
# Load one image from ffhq256/ and show it together with the measurement
# produced by the inpainting operator.
idx = 11
image_name = f"{idx:05d}.png"  # zero-padded file name, e.g. 00011.png
x_true_pil = Image.open('ffhq256/' + image_name)
x_true = pilimg_to_tensor(x_true_pil, device)
print("original image", image_name)
display_as_pilimg(x_true)

# The operator is given a copy of x_true.
# NOTE(review): presumably because it may modify its input in place - confirm.
y = inpainting_operator(x_true.clone(), 256, 256, device)
print("Inpainting operator :")
display_as_pilimg(y)

In [None]:
# Run DPS posterior sampling on the measurement `y` built in the previous
# cell, using the same inpainting operator as forward model, then show the
# result.
# NOTE(review): `x_true` is passed along as well - presumably only for
# monitoring while `show_steps=True`; confirm against dps.posterior_sampling.
x = dps.posterior_sampling(inpainting_operator, y, x_true, show_steps=True, vis_y=None)
print("Restaured image : ")
display_as_pilimg(x)

Run the inpainting experiments:

In [None]:
# Inpainting benchmark: for images 00000.png-00024.png of ffhq256/, build the
# inpainting measurement, restore it with DPS and record LPIPS, PSNR, SSIM and
# the sampling time of every image. The four lists are read by the summary
# cells that follow.
avg_lpips = []
avg_psnr = []
avg_ssim = []
avg_tps = []

loss_fn = lpips.LPIPS(net='alex').to(device)

for idx in range(25):
    image_name = f'{idx:05d}.png'
    x_true_pil = Image.open('ffhq256/' + image_name)
    x_true = pilimg_to_tensor(x_true_pil, device)
    print("Original image", image_name)
    # NOTE(review): assumes the results_dps_inp/ directory already exists -
    # confirm whether display_as_pilimg creates it.
    display_as_pilimg(x_true, save=True, filename=f'results_dps_inp/true{idx}.png')
    y = inpainting_operator(x_true.clone(), 256, 256, device)
    print("Degraded image")
    display_as_pilimg(y, save=True, filename=f'results_dps_inp/degraded{idx}.png')

    # perf_counter is a monotonic clock meant for measuring intervals.
    # NOTE(review): the measured time includes the show_steps=True overhead.
    t0 = time.perf_counter()
    x = dps.posterior_sampling(inpainting_operator, y, x_true, show_steps=True, vis_y=None)
    t1 = time.perf_counter() - t0

    # Metrics only need values, so no autograd graph is built for LPIPS.
    with torch.no_grad():
        xlpips = loss_fn(x, x_true).item()
    xpsnr = psnr(x, x_true)

    # scikit-image has no `size_average` argument (it was silently ignored),
    # and without `channel_axis` the channel dimension is treated as a third
    # spatial axis. Declare it so SSIM is computed per channel and averaged.
    # NOTE(review): assumes tensors are (1, C, H, W) with values in [0, 1]
    # (data_range=1.0) - confirm against pilimg_to_tensor.
    x_np = x.squeeze(0).cpu().detach().numpy()
    x_true_np = x_true.squeeze(0).cpu().detach().numpy()
    xssim = ssim(x_np, x_true_np, win_size=3, data_range=1.0, channel_axis=0)

    print("LPIPS =", xlpips, ", PSNR =", xpsnr, ", SSIM = ", xssim, ", Execution time = ", t1)
    print("Restaured image")
    display_as_pilimg(x, save=True, filename=f'results_dps_inp/restaured{idx}.png')

    avg_lpips.append(xlpips)
    avg_psnr.append(xpsnr)
    avg_ssim.append(xssim)
    avg_tps.append(t1)

In [None]:
# Mean of each per-image metric collected by the inpainting loop.
for metric_label, metric_values in (
    ("LPIPS", avg_lpips),
    ("PSNR", avg_psnr),
    ("SSIM", avg_ssim),
    ("Execution time", avg_tps),
):
    print(f"Average {metric_label} : ", np.mean(metric_values))

In [None]:
# Variance (numpy default, i.e. population variance) of each per-image metric
# collected by the inpainting loop.
for metric_label, metric_values in (
    ("LPIPS", avg_lpips),
    ("PSNR", avg_psnr),
    ("SSIM", avg_ssim),
    ("Execution time", avg_tps),
):
    print(f"Variance {metric_label} : ", np.var(metric_values))

Run the experiments for the blurring operator with noise level 0.05:

In [None]:
# Deblurring benchmark with additive Gaussian noise of standard deviation
# sigma_noise: for images 00000.png-00024.png of ffhq256/, blur the image, add
# noise, restore with DPS and record LPIPS, PSNR, SSIM and sampling time.
sigma_noise = 0.05
avg_lpips = []
avg_psnr = []
avg_ssim = []
avg_tps = []

loss_fn = lpips.LPIPS(net='alex').to(device)

for idx in range(25):
    image_name = f'{idx:05d}.png'
    x_true_pil = Image.open('ffhq256/' + image_name)
    x_true = pilimg_to_tensor(x_true_pil, device)
    print("Original image", image_name)
    # NOTE(review): assumes the results_dps_blur05/ directory already exists -
    # confirm whether display_as_pilimg creates it.
    display_as_pilimg(x_true, save=True, filename=f'results_dps_blur05/true{idx}.png')

    y = blurring_operator(x_true.clone(), device=device) + sigma_noise * torch.randn_like(x_true, device=device)
    print("Degraded image")
    display_as_pilimg(y, save=True, filename=f'results_dps_blur05/degraded{idx}.png')

    # perf_counter is a monotonic clock meant for measuring intervals.
    # NOTE(review): the measured time includes the show_steps=True overhead.
    t0 = time.perf_counter()
    x = dps.posterior_sampling(blurring_operator, y, x_true, show_steps=True, vis_y=None)
    t1 = time.perf_counter() - t0

    # Metrics only need values, so no autograd graph is built for LPIPS.
    with torch.no_grad():
        xlpips = loss_fn(x, x_true).item()
    xpsnr = psnr(x, x_true)

    # scikit-image has no `size_average` argument (it was silently ignored),
    # and without `channel_axis` the channel dimension is treated as a third
    # spatial axis. Declare it so SSIM is computed per channel and averaged.
    # NOTE(review): assumes tensors are (1, C, H, W) with values in [0, 1]
    # (data_range=1.0) - confirm against pilimg_to_tensor.
    x_np = x.squeeze(0).cpu().detach().numpy()
    x_true_np = x_true.squeeze(0).cpu().detach().numpy()
    xssim = ssim(x_np, x_true_np, win_size=3, data_range=1.0, channel_axis=0)

    print("LPIPS =", xlpips, ", PSNR =", xpsnr, ", SSIM = ", xssim, ", Execution time = ", t1)
    print("Restaured image")
    display_as_pilimg(x, save=True, filename=f'results_dps_blur05/restored{idx}.png')

    avg_lpips.append(xlpips)
    avg_psnr.append(xpsnr)
    avg_ssim.append(xssim)
    avg_tps.append(t1)

In [None]:
# Mean of each per-image metric collected by the blur (noise 0.05) loop.
for metric_label, metric_values in (
    ("LPIPS", avg_lpips),
    ("PSNR", avg_psnr),
    ("SSIM", avg_ssim),
    ("Execution time", avg_tps),
):
    print(f"Average {metric_label} : ", np.mean(metric_values))

In [None]:
# Variance (numpy default, i.e. population variance) of each per-image metric
# collected by the blur (noise 0.05) loop.
for metric_label, metric_values in (
    ("LPIPS", avg_lpips),
    ("PSNR", avg_psnr),
    ("SSIM", avg_ssim),
    ("Execution time", avg_tps),
):
    print(f"Variance {metric_label} : ", np.var(metric_values))

Run the experiments for Gaussian blur with noise level 0.1:

In [None]:
# Deblurring benchmark with additive Gaussian noise of standard deviation
# sigma_noise: for images 00000.png-00024.png of ffhq256/, blur the image, add
# noise, restore with DPS and record LPIPS, PSNR, SSIM and sampling time.
sigma_noise = 0.1
avg_lpips = []
avg_psnr = []
avg_ssim = []
avg_tps = []

loss_fn = lpips.LPIPS(net='alex').to(device)

for idx in range(25):
    image_name = f'{idx:05d}.png'
    x_true_pil = Image.open('ffhq256/' + image_name)
    x_true = pilimg_to_tensor(x_true_pil, device)
    print("Original image", image_name)
    # NOTE(review): assumes the results_dps_blur1/ directory already exists -
    # confirm whether display_as_pilimg creates it.
    display_as_pilimg(x_true, save=True, filename=f'results_dps_blur1/true{idx}.png')

    y = blurring_operator(x_true.clone(), device=device) + sigma_noise * torch.randn_like(x_true, device=device)
    print("Degraded image")
    display_as_pilimg(y, save=True, filename=f'results_dps_blur1/degraded{idx}.png')

    # perf_counter is a monotonic clock meant for measuring intervals.
    # NOTE(review): the measured time includes the show_steps=True overhead.
    t0 = time.perf_counter()
    x = dps.posterior_sampling(blurring_operator, y, x_true, show_steps=True, vis_y=None)
    t1 = time.perf_counter() - t0

    # Metrics only need values, so no autograd graph is built for LPIPS.
    with torch.no_grad():
        xlpips = loss_fn(x, x_true).item()
    xpsnr = psnr(x, x_true)

    # scikit-image has no `size_average` argument (it was silently ignored),
    # and without `channel_axis` the channel dimension is treated as a third
    # spatial axis. Declare it so SSIM is computed per channel and averaged.
    # NOTE(review): assumes tensors are (1, C, H, W) with values in [0, 1]
    # (data_range=1.0) - confirm against pilimg_to_tensor.
    x_np = x.squeeze(0).cpu().detach().numpy()
    x_true_np = x_true.squeeze(0).cpu().detach().numpy()
    xssim = ssim(x_np, x_true_np, win_size=3, data_range=1.0, channel_axis=0)

    print("LPIPS =", xlpips, ", PSNR =", xpsnr, ", SSIM = ", xssim, ", Execution time = ", t1)
    print("Restaured image")
    display_as_pilimg(x, save=True, filename=f'results_dps_blur1/restored{idx}.png')

    avg_lpips.append(xlpips)
    avg_psnr.append(xpsnr)
    avg_ssim.append(xssim)
    avg_tps.append(t1)

In [None]:
# Mean of each per-image metric collected by the blur (noise 0.1) loop.
for metric_label, metric_values in (
    ("LPIPS", avg_lpips),
    ("PSNR", avg_psnr),
    ("SSIM", avg_ssim),
    ("Execution time", avg_tps),
):
    print(f"Average {metric_label} : ", np.mean(metric_values))

In [None]:
# Variance (numpy default, i.e. population variance) of each per-image metric
# collected by the blur (noise 0.1) loop.
for metric_label, metric_values in (
    ("LPIPS", avg_lpips),
    ("PSNR", avg_psnr),
    ("SSIM", avg_ssim),
    ("Execution time", avg_tps),
):
    print(f"Variance {metric_label} : ", np.var(metric_values))

Run the experiments for super-resolution (x4) : 

In [None]:
# Super-resolution benchmark: for images 00000.png-00024.png of ffhq256/,
# downsample the image, restore it with DPS and record LPIPS, PSNR, SSIM and
# sampling time.
avg_lpips = []
avg_psnr = []
avg_ssim = []
avg_tps = []

loss_fn = lpips.LPIPS(net='alex').to(device)

for idx in range(25):
    image_name = f'{idx:05d}.png'
    x_true_pil = Image.open('ffhq256/' + image_name)
    x_true = pilimg_to_tensor(x_true_pil, device)
    print("Original image", image_name)
    # NOTE(review): assumes the results_dps_sup/ directory already exists -
    # confirm whether display_as_pilimg creates it.
    display_as_pilimg(x_true, save=True, filename=f'results_dps_sup/true{idx}.png')

    y = downsampling_operator(x_true.clone(), device=device)
    print("Degraded image")
    display_as_pilimg(y, save=True, filename=f'results_dps_sup/degraded{idx}.png')

    # The forward operator handed to DPS must be the one that produced y. The
    # previous code passed blurring_operator here (copy-paste from the blur
    # experiment), so the sampler was guided by the wrong measurement model.
    # perf_counter is a monotonic clock meant for measuring intervals.
    # NOTE(review): the measured time includes the show_steps=True overhead.
    t0 = time.perf_counter()
    x = dps.posterior_sampling(downsampling_operator, y, x_true, show_steps=True, vis_y=None)
    t1 = time.perf_counter() - t0

    # Metrics only need values, so no autograd graph is built for LPIPS.
    with torch.no_grad():
        xlpips = loss_fn(x, x_true).item()
    xpsnr = psnr(x, x_true)

    # scikit-image has no `size_average` argument (it was silently ignored),
    # and without `channel_axis` the channel dimension is treated as a third
    # spatial axis. Declare it so SSIM is computed per channel and averaged.
    # NOTE(review): assumes tensors are (1, C, H, W) with values in [0, 1]
    # (data_range=1.0) - confirm against pilimg_to_tensor.
    x_np = x.squeeze(0).cpu().detach().numpy()
    x_true_np = x_true.squeeze(0).cpu().detach().numpy()
    xssim = ssim(x_np, x_true_np, win_size=3, data_range=1.0, channel_axis=0)

    print("LPIPS =", xlpips, ", PSNR =", xpsnr, ", SSIM = ", xssim, ", Execution time = ", t1)
    print("Restaured image")
    display_as_pilimg(x, save=True, filename=f'results_dps_sup/restaured{idx}.png')
    avg_lpips.append(xlpips)
    avg_psnr.append(xpsnr)
    avg_ssim.append(xssim)
    avg_tps.append(t1)

In [None]:
# Mean of each per-image metric collected by the super-resolution loop.
for metric_label, metric_values in (
    ("LPIPS", avg_lpips),
    ("PSNR", avg_psnr),
    ("SSIM", avg_ssim),
    ("Execution time", avg_tps),
):
    print(f"Average {metric_label} : ", np.mean(metric_values))

In [None]:
# Variance (numpy default, i.e. population variance) of each per-image metric
# collected by the super-resolution loop.
for metric_label, metric_values in (
    ("LPIPS", avg_lpips),
    ("PSNR", avg_psnr),
    ("SSIM", avg_ssim),
    ("Execution time", avg_tps),
):
    print(f"Variance {metric_label} : ", np.var(metric_values))

Run the experiments for super-resolution (x4) with noise level 0.05:

In [None]:
# Noisy super-resolution benchmark: for images 00000.png-00024.png of
# ffhq256/, downsample the image, add Gaussian noise of standard deviation
# sigma_noise, restore with DPS and record LPIPS, PSNR, SSIM and sampling time.
sigma_noise = 0.05
avg_lpips = []
avg_psnr = []
avg_ssim = []
avg_tps = []

loss_fn = lpips.LPIPS(net='alex').to(device)

for idx in range(25):
    image_name = f'{idx:05d}.png'
    x_true_pil = Image.open('ffhq256/' + image_name)
    x_true = pilimg_to_tensor(x_true_pil, device)
    print("Original image", image_name)
    # NOTE(review): assumes the results_dps_sup05/ directory already exists -
    # confirm whether display_as_pilimg creates it.
    display_as_pilimg(x_true, save=True, filename=f'results_dps_sup05/true{idx}.png')

    # Draw the noise with the shape of the measurement itself: the previous
    # code used randn_like(x_true), which only works if the downsampled image
    # happens to have (or broadcast to) the full-resolution shape.
    y_clean = downsampling_operator(x_true.clone(), device=device)
    y = y_clean + sigma_noise * torch.randn_like(y_clean)
    print("Degraded image")
    display_as_pilimg(y, save=True, filename=f'results_dps_sup05/degraded{idx}.png')

    # The forward operator handed to DPS must be the one that produced y. The
    # previous code passed blurring_operator here (copy-paste from the blur
    # experiment), so the sampler was guided by the wrong measurement model.
    # perf_counter is a monotonic clock meant for measuring intervals.
    # NOTE(review): the measured time includes the show_steps=True overhead.
    t0 = time.perf_counter()
    x = dps.posterior_sampling(downsampling_operator, y, x_true, show_steps=True, vis_y=None)
    t1 = time.perf_counter() - t0

    # Metrics only need values, so no autograd graph is built for LPIPS.
    with torch.no_grad():
        xlpips = loss_fn(x, x_true).item()
    xpsnr = psnr(x, x_true)

    # scikit-image has no `size_average` argument (it was silently ignored),
    # and without `channel_axis` the channel dimension is treated as a third
    # spatial axis. Declare it so SSIM is computed per channel and averaged.
    # NOTE(review): assumes tensors are (1, C, H, W) with values in [0, 1]
    # (data_range=1.0) - confirm against pilimg_to_tensor.
    x_np = x.squeeze(0).cpu().detach().numpy()
    x_true_np = x_true.squeeze(0).cpu().detach().numpy()
    xssim = ssim(x_np, x_true_np, win_size=3, data_range=1.0, channel_axis=0)

    print("LPIPS =", xlpips, ", PSNR =", xpsnr, ", SSIM = ", xssim, ", Execution time = ", t1)
    print("Restaured image")
    display_as_pilimg(x, save=True, filename=f'results_dps_sup05/restaured{idx}.png')
    avg_lpips.append(xlpips)
    avg_psnr.append(xpsnr)
    avg_ssim.append(xssim)
    avg_tps.append(t1)

In [None]:
# Mean of each per-image metric collected by the noisy super-resolution loop.
for metric_label, metric_values in (
    ("LPIPS", avg_lpips),
    ("PSNR", avg_psnr),
    ("SSIM", avg_ssim),
    ("Execution time", avg_tps),
):
    print(f"Average {metric_label} : ", np.mean(metric_values))

In [None]:
# Variance (numpy default, i.e. population variance) of each per-image metric
# collected by the noisy super-resolution loop.
for metric_label, metric_values in (
    ("LPIPS", avg_lpips),
    ("PSNR", avg_psnr),
    ("SSIM", avg_ssim),
    ("Execution time", avg_tps),
):
    print(f"Variance {metric_label} : ", np.var(metric_values))

In [None]:
# Noise-free deblurring benchmark: for images 00000.png-00024.png of ffhq256/,
# blur the image, restore it with DPS and record LPIPS, PSNR, SSIM and
# sampling time.
avg_lpips = []
avg_psnr = []
avg_ssim = []
avg_tps = []

loss_fn = lpips.LPIPS(net='alex').to(device)

for idx in range(25):
    image_name = f'{idx:05d}.png'
    x_true_pil = Image.open('ffhq256/' + image_name)
    x_true = pilimg_to_tensor(x_true_pil, device)
    print("Original image", image_name)
    # NOTE(review): assumes the results_dps_blur/ directory already exists -
    # confirm whether display_as_pilimg creates it.
    display_as_pilimg(x_true, save=True, filename=f'results_dps_blur/true{idx}.png')

    y = blurring_operator(x_true.clone(), device=device)
    print("Degraded image")
    display_as_pilimg(y, save=True, filename=f'results_dps_blur/degraded{idx}.png')

    # perf_counter is a monotonic clock meant for measuring intervals.
    # NOTE(review): the measured time includes the show_steps=True overhead.
    t0 = time.perf_counter()
    x = dps.posterior_sampling(blurring_operator, y, x_true, show_steps=True, vis_y=None)
    t1 = time.perf_counter() - t0

    # Metrics only need values, so no autograd graph is built for LPIPS.
    with torch.no_grad():
        xlpips = loss_fn(x, x_true).item()
    xpsnr = psnr(x, x_true)

    # scikit-image has no `size_average` argument (it was silently ignored),
    # and without `channel_axis` the channel dimension is treated as a third
    # spatial axis. Declare it so SSIM is computed per channel and averaged.
    # NOTE(review): assumes tensors are (1, C, H, W) with values in [0, 1]
    # (data_range=1.0) - confirm against pilimg_to_tensor.
    x_np = x.squeeze(0).cpu().detach().numpy()
    x_true_np = x_true.squeeze(0).cpu().detach().numpy()
    xssim = ssim(x_np, x_true_np, win_size=3, data_range=1.0, channel_axis=0)

    print("LPIPS =", xlpips, ", PSNR =", xpsnr, ", SSIM = ", xssim, ", Execution time = ", t1)
    print("Restaured image")
    display_as_pilimg(x, save=True, filename=f'results_dps_blur/restored{idx}.png')

    avg_lpips.append(xlpips)
    avg_psnr.append(xpsnr)
    avg_ssim.append(xssim)
    avg_tps.append(t1)

In [None]:
# Mean of each per-image metric collected by the noise-free blur loop.
for metric_label, metric_values in (
    ("LPIPS", avg_lpips),
    ("PSNR", avg_psnr),
    ("SSIM", avg_ssim),
    ("Execution time", avg_tps),
):
    print(f"Average {metric_label} : ", np.mean(metric_values))

In [None]:
# Variance (numpy default, i.e. population variance) of each per-image metric
# collected by the noise-free blur loop.
for metric_label, metric_values in (
    ("LPIPS", avg_lpips),
    ("PSNR", avg_psnr),
    ("SSIM", avg_ssim),
    ("Execution time", avg_tps),
):
    print(f"Variance {metric_label} : ", np.var(metric_values))