In [1]:
import argparse
import os

import matplotlib.pyplot as plt
import torch
import yaml

from data import get_dataset, get_dataloader
from denoisers.DnCNN.get_dncnn import create_model_DnCNN
from denoisers.hierarquicalVAE.get_VAE import create_model_nvae
from guided_diffusion.unet import create_model
from guided_diffusion.gaussian_diffusion import create_sampler
from util.img_utils import clear_color
from tasks import create_operator
from gibbs_sampler import GibbsSampler


  from .autonotebook import tqdm as notebook_tqdm
  if sampler is None and model_type is 'DDMP':


In [None]:
# Check device: use the first CUDA GPU when available, otherwise the CPU.
# Note: "cuda:" with an empty index is not a valid torch device string,
# so the index is spelled out explicitly.
device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_str)

# Seed torch's RNG so the simulated noise and the stochastic sampler are reproducible.
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Degradation model used in this notebook: a Gaussian blur, built as the
# linear forward operator H.
operator_config = dict(
    name="gaussian_blur",
    kernel_size=61,
    intensity=3.0,  # blur strength; presumably the kernel std — confirm in `tasks`
    channels=3,
    img_dim=256,
)

H_gaussian = create_operator(device=device, **operator_config)


In [None]:
# Load the FFHQ test images once and wrap them in a loader that yields one
# image per batch (train=False: presumably no shuffling — confirm in `data`).
dataset = get_dataset("ffhq", root="data/samples_ffhq")
num_test_images = len(dataset)
dataloader = get_dataloader(dataset, batch_size=1, num_workers=0, train=False)

## Using a DDPM diffusion model as the denoiser

In [None]:
# Diffusion-process settings for the sampler handed to the Gibbs sampler
# (presumably used for the denoising z-step). They are passed verbatim
# to `create_sampler`.
diffusion_config = dict(
    sampler="ddpm",
    steps=1000,
    noise_schedule="linear",
    model_mean_type="epsilon",
    model_var_type="learned_range",
    dynamic_threshold=False,
    clip_denoised=True,
    rescale_timesteps=False,
    timestep_respacing=1000,
)

diffusion_sampler = create_sampler(**diffusion_config)

In [None]:
# Locally stored pre-trained diffusion checkpoints.
possible_diffusion_models = {
    'ffhq_10m': 'models/ffhq_10m.pt',
    'imagenet': 'models/imagenet256.pt',
}

# UNet hyper-parameters (presumably they must match the checkpoint in
# `model_path`); the FFHQ checkpoint is used here.
model_config = dict(
    image_size=256,
    num_channels=128,
    num_res_blocks=1,
    channel_mult="",
    learn_sigma=True,
    class_cond=False,
    use_checkpoint=False,
    attention_resolutions=16,
    num_heads=4,
    num_head_channels=64,
    num_heads_upsample=-1,
    use_scale_shift_norm=True,
    dropout=0.0,
    resblock_updown=True,
    use_fp16=False,
    use_new_attention_order=False,
    model_path=possible_diffusion_models['ffhq_10m'],
)

# Forwarded to GibbsSampler, which appears to compare it against the literal
# 'DDMP' (see the warning printed by the import cell) — keep this spelling
# unless GibbsSampler is changed as well.
model_type = 'DDMP'

# Build and load the pre-trained model, move it to `device`, switch to inference mode.
model_DDMP = create_model(**model_config)
model_DDMP.to(device)
model_DDMP.eval()

## Sample

Here we exploit the fact that the problem is split into two alternating steps: (1) sampling $x$ such that it stays close to $z$, and (2) sampling $z$ as a denoised image.

In [None]:
# Simulate a degraded observation from one test image: Y = H X + n,
# with n ~ N(0, sigma^2 I).
X = next(iter(dataloader)).to(device)

sigma = torch.tensor(0.05, device=device)  # observation-noise standard deviation
HX = H_gaussian.forward(X)
# Draw the noise with the shape of H X (not X) so this stays correct for
# operators that change the image size, and generate it directly on `device`.
Y = HX + sigma * torch.randn_like(HX)

# Plot the noisy observation and save it under results/.
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(clear_color(Y))
ax.set_title('Noisy image')
ax.axis('off')
os.makedirs("results", exist_ok=True)  # savefig fails if the directory is missing
fig.savefig("results/image_noisy.png", dpi=200, bbox_inches='tight')

plt.show()


In [None]:
# Execute sampling: run the Gibbs sampler, discard the burn-in iterations and
# average the remaining samples.
N_bi = 20  # burn-in iterations (discarded)
N_MC = 23  # total number of iterations
assert 0 <= N_bi < N_MC, "burn-in must leave at least one retained sample"

gibbs = GibbsSampler(
    Y=Y,
    sigma=sigma,
    operator=H_gaussian,
    sampler=diffusion_sampler,
    model=model_DDMP,
    model_type=model_type,
    device=device,
    # Pass the constants (not literals) so the sampler and the averaging
    # slice below always agree.
    N_MC=N_MC,
    N_bi=N_bi,
    rho=0.1,
    rho_decay_rate=0.8,
    plot_process=5,  # presumably a plotting interval in iterations — confirm in GibbsSampler
)

X_MC, Z_MC = gibbs.run()

# Average the post-burn-in samples. Assumes the samples are stacked along
# the 4th (last) axis, as the original slicing implies — TODO confirm.
Z_mean = torch.mean(Z_MC[:, :, :, N_bi:N_MC], dim=-1)
X_mean = torch.mean(X_MC[:, :, :, N_bi:N_MC], dim=-1)

panels = [
    (X, 'True image'),
    (Y, 'Noisy image'),
    (Z_mean, 'Z Reconstructed image'),
    (X_mean, 'X Reconstructed image'),
]
fig, axes = plt.subplots(1, len(panels), figsize=(20, 20))
for ax, (img, title) in zip(axes, panels):
    ax.imshow(clear_color(img))
    ax.set_title(title)
    ax.axis('off')

os.makedirs("results", exist_ok=True)  # savefig fails if the directory is missing
fig.savefig("results/example_test.png", dpi=200, bbox_inches='tight')
plt.show()

## Comparison with the output of a single reverse diffusion step

Here, we estimate the noise level of the observed (noisy) image $Y$, and then pass it through the reverse (backward) diffusion process to recover an estimate of $X$.