This project is based on the Diffusion Posterior Sampling (DPS) implementation (https://github.com/DPS2022/diffusion-posterior-sampling.git), so that it can reuse that code's structure, which in turn follows Guided Diffusion (https://github.com/openai/guided-diffusion.git) and ILVR-ADM (https://github.com/jychoi118/ilvr_adm.git).
Both ILVR-ADM and DPS are re-implementations of Guided Diffusion.

I have focused on MPGD without Projection on the task of Super Resolution.

The full code can be found here:

In [1]:
!git clone https://github.com/giorgiacucino/MPGD.git

Cloning into 'MPGD'...
remote: Enumerating objects: 46, done.
remote: Counting objects: 100% (46/46), done.
remote: Compressing objects: 100% (39/39), done.
remote: Total 46 (delta 9), reused 42 (delta 7), pack-reused 0
Receiving objects: 100% (46/46), 16.18 MiB | 15.41 MiB/s, done.
Resolving deltas: 100% (9/9), done.


# Implementation of MPGD shortcut

To implement Manifold Preserving Guided Diffusion, we have to re-implement the loop that denoises the initial pure-noise sample, i.e. the loop from x_T down to x_0.

In [None]:
class MPGD(SpacedDiffusion):
    """Sampler that runs the Manifold Preserving Guided Diffusion (MPGD) loop.

    Only the sampling loop is overridden here; everything else comes from
    ``SpacedDiffusion``.
    """

    def p_sample_loop(self,
                      model,
                      x_start,
                      measurement,
                      measurement_cond_fn,
                      record,
                      save_root):
        """Denoise ``x_start`` step by step while guiding the clean-image estimate.

        Args:
            model: Callable invoked as ``model(x_t, scaled_t)``; its output is
                used as the noise prediction eps_theta(x_t, t).
            x_start: Initial sample; its device is used for the timestep tensors.
            measurement: Measurement forwarded unchanged to ``measurement_cond_fn``.
            measurement_cond_fn: Callable that receives the current estimate and
                returns ``(updated x_0|t, distance)``.
            record: When truthy, intermediate images are written every 10 steps.
            save_root: Directory whose ``progress`` sub-folder receives those
                images (this method does not create the folder).

        Returns:
            The sample obtained after the last (``idx == 0``) step.
        """
        img = x_start
        device = x_start.device

        # Iterate over the timesteps from the last one down to 0.
        pbar = tqdm(list(range(self.num_timesteps))[::-1])
        for idx in pbar:
            time = torch.tensor([idx] * img.shape[0], device=device)
            # Predict the model output, i.e. eps_theta(x_t, t)
            model_output = model(img, self._scale_timesteps(time))

            # In the case of "learned" variance the model returns twice the
            # channels; only the first half (the noise prediction) is used here.
            if model_output.shape[1] == 2 * img.shape[1]:
                model_output, _ = torch.split(model_output, img.shape[1], dim=1)

            # Predict x_0|t
            pred_xstart = self.mean_processor.predict_xstart(x_t=img, t=time, eps=model_output)

            # Clone x_0|t so that gradients can be taken with respect to it
            x_0_t = pred_xstart.clone()
            x_0_t = x_0_t.requires_grad_()

            # Calculate the scaling factor c_t based on alpha_bar and alpha_bar_prev.
            # This implementation is based on the formula at page 17 of the paper.
            alpha_bar = extract_and_expand(self.alphas_cumprod, time, img)
            alpha_bar_prev = extract_and_expand(self.alphas_cumprod_prev, time, img)
            scale = 1.0 / (torch.sqrt(alpha_bar * alpha_bar_prev))

            # Update the predicted x_start through measurement_cond_fn, which is
            # selected by the configuration file. The clone is passed both as
            # x_prev and as x_0_hat; c_t is forwarded as the ``scale`` keyword.
            pred_xstart, distance = measurement_cond_fn(x_t=pred_xstart,
                                                        measurement=measurement,
                                                        x_prev=x_0_t,
                                                        x_0_hat=x_0_t,
                                                        idx=idx,
                                                        timesteps=self.num_timesteps,
                                                        scale=scale)
            x_0_t = x_0_t.detach_()
            with torch.no_grad():
                # Calculate sigma with eta=0.5 (as in the implementation of the paper)
                sigma = (0.5 * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * torch.sqrt(1 - alpha_bar / alpha_bar_prev))
                noise = torch.randn_like(img)
                # x_{t-1} = x_0|t * sqrt(alpha_bar_prev) + sqrt(1 - alpha_bar_prev - sigma^2) * model_output
                # (as in the implementation of the paper)
                img = (pred_xstart * torch.sqrt(alpha_bar_prev) + torch.sqrt(1 - alpha_bar_prev - (sigma ** 2)) * model_output)
                # Test the integer step index rather than the batch-sized ``time``
                # tensor: truth-testing a tensor with more than one element raises,
                # so the previous ``time != 0`` check only worked for batch size 1.
                if idx != 0:
                    img += sigma * noise
            pbar.set_postfix({'distance': distance.item()}, refresh=False)
            if record and idx % 10 == 0:
                # Output files for debugging purposes: the new sample, the detached
                # clone of x_0|t, the model output and the guided x_0|t.
                step = str(idx).zfill(4)
                snapshots = (("img", img), ("x", x_0_t), ("eps", model_output), ("pred", pred_xstart))
                for prefix, tensor in snapshots:
                    file_path = os.path.join(save_root, f"progress/{prefix}_{step}.png")
                    plt.imsave(file_path, clear_color(tensor))

        return img

# Conditioning Method

The `measurement_cond_fn` is set to the `conditioning` method of the following class through the configuration file "configs/mpgd_super_resolution_config.yaml"

In [None]:
@register_conditioning_method(name='mpg')
class ManifoldPreservingGradient(ConditioningMethod):
    """Conditioning method registered under the name 'mpg'.

    It takes a gradient step on the clean-data estimate and, during the first
    half of the sampling process, additionally projects the result.
    """

    def __init__(self, operator, noiser, **kwargs):
        super().__init__(operator, noiser)
        # Step size of the gradient update; 1.0 when no 'scale' is supplied.
        self.scale = kwargs.get('scale', 1.0)

    def conditioning(self, x_prev, x_t, x_0_hat, measurement, idx, timesteps, **kwargs):
        """Return the updated estimate together with the value from grad_and_value.

        Extra keyword arguments are forwarded unchanged to ``grad_and_value``
        and ``project``; the step size always comes from ``self.scale``.
        Note that ``x_t`` is updated in place.
        """
        # Gradient and norm computed by the parent-class helper
        grad, distance = self.grad_and_value(x_prev=x_prev, x_0_hat=x_0_hat, measurement=measurement, **kwargs)
        x_t -= grad * self.scale

        # Sampling counts idx downwards, so idx > timesteps / 2 is the first
        # half of the process; outside of it no projection is applied.
        in_first_half = idx > timesteps / 2
        if not in_first_half:
            return x_t, distance

        # Project the data in the first half of the sampling process (different from the paper)
        projected = self.project(data=x_t, noisy_measurement=measurement, **kwargs)
        return projected, distance

This differs from the paper: after updating the clean-data estimate, I additionally project it (only during the first half of the sampling process).

I did this because empirically I found out that:

1.   if the data is not projected, it tends to diverge from the original input
2.   if the data is projected for the whole process, it tends to preserve some noise from the measurement

Moreover, projecting for the whole process has a negative impact on the model output calculated on the next iteration

# How to run the code

First of all, you will need to clone the repository with the code from github

In [1]:
!git clone https://github.com/giorgiacucino/MPGD.git

Cloning into 'MPGD'...
remote: Enumerating objects: 46, done.
remote: Counting objects: 100% (46/46), done.
remote: Compressing objects: 100% (39/39), done.
remote: Total 46 (delta 9), reused 42 (delta 7), pack-reused 0
Receiving objects: 100% (46/46), 16.18 MiB | 17.02 MiB/s, done.
Resolving deltas: 100% (9/9), done.


Next, create a ./models directory to store the pretrained model

In [2]:
cd /content/MPGD

/content/MPGD


In [3]:
mkdir /content/MPGD/models

You can get the pretrained model from this [link](https://drive.google.com/drive/folders/1jElnRoFv7b31fG0v6pTSQkelbSX3xGZh)

If the model is saved on Google Drive, run the next cell

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
!cp /content/drive/MyDrive/ffhq_10m.pt /content/MPGD/models

In [6]:
!git clone https://github.com/VinAIResearch/blur-kernel-space-exploring bkse

!git clone https://github.com/LeviBorodenko/motionblur motionblur

Cloning into 'bkse'...
remote: Enumerating objects: 565, done.
remote: Counting objects: 100% (565/565), done.
remote: Compressing objects: 100% (316/316), done.
remote: Total 565 (delta 327), reused 461 (delta 232), pack-reused 0
Receiving objects: 100% (565/565), 1.04 MiB | 10.88 MiB/s, done.
Resolving deltas: 100% (327/327), done.
Cloning into 'motionblur'...
remote: Enumerating objects: 36, done.
remote: Total 36 (delta 0), reused 0 (delta 0), pack-reused 36
Receiving objects: 100% (36/36), 511.08 KiB | 5.16 MiB/s, done.
Resolving deltas: 100% (12/12), done.


In [7]:
import torch

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [21]:
import yaml

#From DPS implementation (sample_condition.py)
def load_yaml(file_path: str) -> dict:
    """Read a YAML file and return its parsed content.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        The object produced by the YAML parser (the configuration files used
        in this notebook are expected to parse to a dict).
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # default, so the same file parses identically on every machine.
    with open(file_path, encoding="utf-8") as f:
        # NOTE(review): FullLoader is kept for compatibility with the upstream
        # code; prefer yaml.safe_load if a config could come from an untrusted
        # source.
        config = yaml.load(f, Loader=yaml.FullLoader)
    return config

#Change the configuration files here
# The paths are loaded in order: model, diffusion and task configuration.
model_config, diffusion_config, task_config = (
    load_yaml(config_path)
    for config_path in (
        "./configs/model_config.yaml",
        "./configs/mpgd_diffusion_config.yaml",
        "./configs/mpgd_super_resolution_config.yaml",
    )
)

import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from functools import partial
import os

from guided_diffusion.condition_methods import get_conditioning_method
from guided_diffusion.measurements import get_noise, get_operator
from guided_diffusion.unet import create_model
from guided_diffusion.gaussian_diffusion import create_sampler
from data.dataloader import get_dataset, get_dataloader
from util.img_utils import clear_color, mask_generator
from util.logger import get_logger

def start_sampling(model_config, diffusion_config, task_config, device):
    """Reconstruct every image of the configured dataset and save the results.

    For each image the measurement is simulated (operator followed by noise),
    a sample is drawn with the configured sampler starting from random noise,
    and the measurement, the reference image and the reconstruction are written
    to ``./results/<operator name>/{input,label,recon}``. Intermediate images
    are recorded by the sampler into the ``progress`` sub-folder.

    Args:
        model_config: Keyword arguments for ``create_model``.
        diffusion_config: Keyword arguments for ``create_sampler``.
        task_config: Mapping with the 'measurement', 'conditioning' and 'data'
            sections read below.
        device: Device the model and the tensors are moved to.
    """
    # Network, moved to the target device and switched to evaluation mode
    model = create_model(**model_config).to(device)
    model.eval()

    # Measurement operator and noise model
    measure_config = task_config['measurement']
    operator_config = measure_config['operator']
    operator = get_operator(device=device, **operator_config)
    noiser = get_noise(**measure_config['noise'])

    # Conditioning method selected by the task configuration
    cond_config = task_config['conditioning']
    cond_method = get_conditioning_method(cond_config['method'], operator, noiser, **cond_config['params'])
    default_cond_fn = cond_method.conditioning

    # Sampler with the model and the conditioning function pre-bound
    sampler = create_sampler(**diffusion_config)
    sample_fn = partial(sampler.p_sample_loop, model=model, measurement_cond_fn=default_cond_fn)

    # Output folders
    operator_name = operator_config['name']
    is_inpainting = operator_name == 'inpainting'
    out_path = os.path.join("./results", operator_name)
    os.makedirs(out_path, exist_ok=True)
    for sub_dir in ('input', 'recon', 'progress', 'label'):
        os.makedirs(os.path.join(out_path, sub_dir), exist_ok=True)

    # Dataset: images are converted to tensors and normalised with mean 0.5 / std 0.5
    data_config = task_config['data']
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = get_dataset(**data_config, transforms=transform)
    loader = get_dataloader(dataset, batch_size=1, num_workers=0, train=False)

    # Inpainting is the only task that needs a mask generator
    if is_inpainting:
        mask_gen = mask_generator(**measure_config['mask_opt'])

    # Inference, one image at a time
    for i, ref_img in enumerate(loader):
        fname = f"{i:05d}.png"
        ref_img = ref_img.to(device)

        forward_kwargs = {}
        if is_inpainting:
            # Inpainting: draw a mask for this image and bind it both to the
            # conditioning function and to the forward operator.
            mask = mask_gen(ref_img)
            mask = mask[:, 0, :, :].unsqueeze(dim=0)
            masked_cond_fn = partial(cond_method.conditioning, mask=mask)
            sample_fn = partial(sample_fn, measurement_cond_fn=masked_cond_fn)
            forward_kwargs['mask'] = mask

        # Forward measurement model (Ax + n)
        y = operator.forward(ref_img, **forward_kwargs)
        y_n = noiser(y)

        # Sampling starts from Gaussian noise with the shape of the reference image
        x_start = torch.randn(ref_img.shape, device=device).requires_grad_()
        sample = sample_fn(x_start=x_start, measurement=y_n, record=True, save_root=out_path)

        # Save the measurement, the ground truth and the reconstruction
        for sub_dir, tensor in (('input', y_n), ('label', ref_img), ('recon', sample)):
            plt.imsave(os.path.join(out_path, sub_dir, fname), clear_color(tensor))

In [31]:
start_sampling(model_config, diffusion_config, task_config, device)

  0%|          | 0/1000 [00:00<?, ?it/s]

# Results

To visualize the results, I have created a function that generates the gifs for some of the data used in the loop.

The result can be found in the gif img.gif

In [33]:
import os
from PIL import Image
import torch
import imageio

def generate_gif(image_folder, output_gif_path, start):
    """Assemble the PNG files of a folder that share a prefix into one gif.

    Args:
        image_folder: Folder that is scanned for frames.
        output_gif_path: Path of the gif file to write.
        start: File-name prefix a frame must have; only names that also end in
            ".png" are used. The number between the first "_" and the first
            following "." of the name is the frame index.

    The frames are written by decreasing index, i.e. the highest-numbered
    frame comes first.
    """
    def frame_index(file_name):
        # "<prefix>_<index>.png" -> index as an integer
        return int(file_name.split('_')[1].split('.')[0])

    frame_names = sorted(
        (name for name in os.listdir(image_folder)
         if name.startswith(start) and name.endswith(".png")),
        key=frame_index,
    )

    # Open the frames from the highest index to the lowest
    frames = [Image.open(os.path.join(image_folder, name)) for name in reversed(frame_names)]
    imageio.mimsave(output_gif_path, frames, duration=0.5)

# Folder holding the intermediate images recorded during sampling
input_folder = "./results/super_resolution/progress"

# Create the output directory from Python instead of the IPython `!mkdir` shell
# escape: the cell is then plain Python, and re-running it does not report an
# error when the directory already exists.
os.makedirs("./gifs", exist_ok=True)

# Output location and file names of the generated gifs
output_folder = "./gifs/"
x_0 = "x_0.gif"
img = "img.gif"
pred = "pred.gif"
eps = "eps.gif"

# One gif per family of progress images, selected by file-name prefix
generate_gif(input_folder, output_folder + x_0, "x_")
generate_gif(input_folder, output_folder + img, "img_")
generate_gif(input_folder, output_folder + pred, "pred_")
generate_gif(input_folder, output_folder + eps, "eps_")
