In [1]:
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

import numpy as np
import cv2

from guided_diffusion.unet import create_model
from guided_diffusion.gaussian_diffusion import create_sampler, get_named_beta_schedule
from util.img_utils import clear_color
from tasks.motion_blur import MotionBlurCircular
from tqdm import tqdm

# Seed torch's RNG first so the noise realization and the Gibbs chain below are
# reproducible on Restart & Run All (torch.manual_seed seeds all devices; some
# CUDA kernels may still be nondeterministic).
SEED = 0
torch.manual_seed(SEED)

# Device setting: first CUDA device if available, otherwise CPU.
DEVICE_INDEX = 0
device_str = f"cuda:{DEVICE_INDEX}" if torch.cuda.is_available() else 'cpu'
device = torch.device(device_str)

In [2]:
# Load the pretrained network (guided_diffusion UNet, unconditional since
# class_cond=False) from models/ffhq_10m.pt, move it to `device` and switch
# it to evaluation mode.
model = create_model(
    model_path='models/ffhq_10m.pt',
    image_size=256,
    num_channels=128,
    num_res_blocks=1,
    channel_mult='',
    attention_resolutions="16",
    num_heads=4,
    num_head_channels=64,
    num_heads_upsample=-1,
    learn_sigma=True,
    class_cond=False,
    use_checkpoint=False,
    use_scale_shift_norm=True,
    resblock_updown=True,
    use_new_attention_order=False,
    use_fp16=False,
    dropout=0.0,
).to(device)
model.eval()

# Diffusion sampler: DDPM over 1000 steps with the linear noise schedule,
# configured to match the network above (epsilon prediction, learned variance range).
sampler = create_sampler(
    sampler='ddpm',
    steps=1000,
    noise_schedule='linear',
    model_mean_type='epsilon',
    model_var_type='learned_range',
    rescale_timesteps=False,
    clip_denoised=True,
    dynamic_threshold=False,
)

In [3]:
img_path = '00001'

# Load the ground-truth image. OpenCV reads BGR uint8 pixels in [0, 255].
img_file = 'data/samples_ffhq/' + img_path + '.png'
X_full = cv2.imread(img_file)
if X_full is None:
    # cv2.imread does not raise on a missing/unreadable file; it returns None,
    # which would otherwise surface as a cryptic error inside cv2.resize.
    raise FileNotFoundError(f"could not read image: {img_file}")
X_full = cv2.resize(X_full, dsize=(256, 256), interpolation=cv2.INTER_CUBIC)
X_full = cv2.cvtColor(X_full, cv2.COLOR_BGR2RGB)
X0 = X_full[:, :, 0]
X1 = X_full[:, :, 1]
X2 = X_full[:, :, 2]
N = X1.shape
N_full = X_full.shape

# Set the noise level from a target BSNR.
BSNR = 60  # SNR expressed in decibels
# The noise is added to the normalized tensor, not to raw pixels: ToTensor followed
# by Normalize((0.5,)*3, (0.5,)*3) maps v in [0, 255] to v / 127.5 - 1 in [-1, 1].
# Measuring the signal power on the raw uint8 scale made sigma 127.5x too large
# for the data it is added to (effective BSNR ~42 dB below the stated value).
# NOTE(review): BSNR is usually defined w.r.t. the *blurred* image; the sharp
# image's variance is used here, as before.
PIXEL_SCALE = 127.5
P_signal = X_full.var() / PIXEL_SCALE**2  # signal power on the [-1, 1] scale
sigma = np.sqrt(P_signal / 10**(BSNR / 10))  # standard deviation of the noise

In [4]:
# Image -> tensor pipeline: ToTensor turns a uint8 HxWxC array in [0, 255] into a
# float CxHxW tensor in [0, 1]; Normalize then computes (x - 0.5) / 0.5 per
# channel, giving values in [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])

# Degradation operator (MotionBlurCircular) for 3-channel 256x256 images,
# kernel_size=61, intensity=0.5.
operator = MotionBlurCircular(
    img_dim=256,
    channels=3,
    kernel_size=61,
    intensity=0.5,
    device=device,
)

In [5]:
# Ground-truth image as a normalized tensor on the compute device.
X = transform(X_full).to(device)
N = X.shape  # NOTE: rebinds N (previously the 2-D image size) to the full tensor shape
# Observation: Y = operator(X) + sigma * eps, with eps i.i.d. standard Gaussian.
# randn_like draws eps directly with X's shape, dtype and device instead of
# allocating it on the CPU and copying it over.
Y = operator.forward(X) + sigma * torch.randn_like(X)

In [None]:
# Sanity check: a single likelihood (proximal) step from a random initialization.
Z = torch.randn(N, device=device)
X_rec = operator.proximal_generator(Z, Y, sigma, rho=0.1)

# Three square images side by side: keep the figure wide but short
# (a 20x20 canvas left most of the figure as whitespace).
fig, axes = plt.subplots(1, 3, figsize=(20, 7))
panels = [
    (X, 'True image'),
    (Y, 'Noisy image'),
    (X_rec, 'First likelihood step'),
]
for ax, (img, title) in zip(axes, panels):
    ax.imshow(clear_color(img))
    ax.set_title(title)
    ax.axis('off')

In [8]:
# Cumulative sum of the 1000-step linear beta schedule, scaled so its maximum is 1.
# It is used below to map a noise level to a diffusion step index.
# (Despite the name, these are not the usual diffusion alphas = 1 - betas.)
betas = get_named_beta_schedule('linear', 1000)
alphas = np.cumsum(betas)
alphas = alphas / alphas.max()

def estimate_time(value, array=alphas):
    """Return the index of the entry of `array` closest to `value`.

    The default `array` is bound once, at definition time, to the normalized
    cumulative beta curve above. Ties resolve to the lowest index (np.argmin).
    """
    distances = np.abs(np.asarray(array) - value)
    return np.argmin(distances)

In [9]:
def compute_last_diff_step(t_start, N_bi, iteration=None, stop_fraction=0.7):
    """Return the diffusion step index at which the prior step should stop.

    During burn-in (Gibbs iteration < N_bi) the reverse diffusion is stopped
    early, at ``int(t_start * stop_fraction)``. After burn-in it runs all the
    way down to step 0.

    Parameters
    ----------
    t_start : int
        Step index the reverse diffusion starts from.
    N_bi : int
        Number of burn-in iterations of the Gibbs sampler.
    iteration : int, optional
        Current Gibbs iteration. If omitted, the notebook-global loop counter
        ``t`` is used (this is what the original two-argument version silently
        relied on). Passing it explicitly is preferred.
    stop_fraction : float, optional
        Fraction of ``t_start`` at which to stop during burn-in (default 0.7).

    Returns
    -------
    int
        The last diffusion step index.

    Raises
    ------
    NameError
        If ``iteration`` is omitted and no global ``t`` exists.
    """
    if iteration is None:
        # Backward-compatible fallback: the Gibbs loop calls this with two
        # arguments while its loop counter `t` lives in the notebook globals.
        if 't' not in globals():
            raise NameError(
                "compute_last_diff_step: pass iteration=... (no global Gibbs counter `t`)"
            )
        iteration = globals()['t']
    if iteration < N_bi:
        return int(t_start * stop_fraction)
    return 0

In [10]:
# Gibbs sampler settings
N_MC = 23             # total number of Gibbs iterations
N_bi = 20             # burn-in iterations (prior step stops early while t < N_bi)
rho = 0.1             # coupling parameter passed to the likelihood step
rho_decay_rate = 0.8  # geometric decay of the level that sets the prior step's start
N_DIFF_STEPS = 1000   # diffusion length; must match the sampler's `steps` and the beta schedule

# Initialization: preallocate both chains with the image shape N plus a trailing
# iterate axis. Slot 0 holds the random initialization; slot t+1 holds the
# output of Gibbs iteration t.
chain_shape = (*N, N_MC + 1)
X_MC = torch.zeros(size=chain_shape, device=device)
Z_MC = torch.zeros(size=chain_shape, device=device)
X_MC[..., 0] = torch.randn(N, device=device)
Z_MC[..., 0] = torch.randn(N, device=device)

# Gibbs sampling. The loop variable must stay named `t`: compute_last_diff_step
# falls back to reading it from the notebook globals when no iteration is passed.
for t in tqdm(range(N_MC)):
    # Likelihood step (operator.proximal_generator), from the current Z and the observation Y.
    # NOTE(review): this uses the constant `rho`, whereas the prior step below is
    # driven by the annealed `rho_iter`; confirm which is intended.
    X_MC[..., t + 1] = operator.proximal_generator(Z_MC[..., t], Y, sigma, rho=rho)

    # Annealed level -> diffusion step indices where the prior step starts/stops.
    rho_iter = rho * (rho_decay_rate ** t)
    t_start = estimate_time(rho_iter)
    t_stop = compute_last_diff_step(t_start, N_bi)

    # Prior step: sampler.diffuse_back starting from the new X. Step indices are
    # passed as N_DIFF_STEPS - index, as in the original call.
    Z_MC[..., t + 1] = sampler.diffuse_back(
        x=X_MC[..., t + 1].unsqueeze(0),
        model=model,
        t_start=N_DIFF_STEPS - t_start,
        t_end=N_DIFF_STEPS - t_stop,
    ).squeeze(0)

100%|██████████| 23/23 [00:44<00:00,  1.94s/it]


In [None]:
# Monte Carlo averages over the post-burn-in samples. Slot t+1 of each chain holds
# the output of Gibbs iteration t, so iterations t >= N_bi live in slots
# N_bi+1 .. N_MC inclusive. The slice N_bi:N_MC would instead include the last
# burn-in iterate (whose prior step stopped early) and drop the final sample.
assert N_MC > N_bi, "need at least one post-burn-in iteration to average"
post_burn_in = slice(N_bi + 1, N_MC + 1)
Z_mean = torch.mean(Z_MC[..., post_burn_in], dim=-1)
X_mean = torch.mean(X_MC[..., post_burn_in], dim=-1)

# Four square images side by side: wide but short figure.
fig, axes = plt.subplots(1, 4, figsize=(20, 6))
panels = [
    (X, 'True image'),
    (Y, 'Noisy image'),
    (Z_mean, 'Z Reconstructed image'),
    (X_mean, 'X Reconstructed image'),
]
for ax, (img, title) in zip(axes, panels):
    ax.imshow(clear_color(img))
    ax.set_title(title)
    ax.axis('off')

In [None]:
# Inspect every stored iterate of both chains: slot 0 is the random
# initialization and slot N_MC is the final sample (range(N_MC) skipped it).
for i in range(N_MC + 1):
    fig, axes = plt.subplots(1, 2, figsize=(20, 10))
    axes[0].imshow(clear_color(Z_MC[..., i]))
    axes[0].set_title(f'Z, slot {i}')
    axes[1].imshow(clear_color(X_MC[..., i]))
    axes[1].set_title(f'X, slot {i}')
    plt.show()
    plt.close(fig)  # this loop creates N_MC + 1 figures; release each after display