<a href="https://colab.research.google.com/github/mlvlab/data303/blob/main/Image_Restoraion_by_InverseProblems.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 0. Download the DDNM git repo and prepare requirements

In [1]:
!git clone https://github.com/wyhuai/DDNM.git

Cloning into 'DDNM'...
remote: Enumerating objects: 607, done.[K
remote: Counting objects: 100% (184/184), done.[K
remote: Compressing objects: 100% (116/116), done.[K
remote: Total 607 (delta 156), reused 68 (delta 68), pack-reused 423[K
Receiving objects: 100% (607/607), 14.30 MiB | 17.12 MiB/s, done.
Resolving deltas: 100% (276/276), done.


In [2]:
cd DDNM

/content/DDNM


In [3]:
import argparse
import traceback
import shutil
import logging
import yaml
import sys
import os
import torch
import numpy as np
import torch.utils.tensorboard as tb
import random
import torchvision.utils as tvu

In [None]:
def dict2namespace(config):
    """Recursively turn a (YAML-loaded) dict into nested ``argparse.Namespace`` objects.

    Nested dicts become nested namespaces; every other value is stored as-is,
    so ``cfg['data']['image_size']`` becomes ``ns.data.image_size``.
    """
    converted = {
        name: dict2namespace(entry) if isinstance(entry, dict) else entry
        for name, entry in config.items()
    }
    return argparse.Namespace(**converted)

In [None]:
def parse_args_and_config(args):
    """Load the YAML config named by ``args.config`` and prepare the run.

    Side effects (the rest of the notebook relies on them):
      * sets ``args.new_image_folder`` (and creates it on disk) and ``args.device``;
      * copies ``args.T_sampling`` / ``travel_length`` / ``travel_repeat`` into
        ``config.time_travel`` (when that section exists), so the notebook's
        optional sampling knobs drive sampling and not just the folder name;
      * attaches one stream handler to the root logger (only once per kernel)
        and sets its level from ``args.verbose``;
      * seeds ``random``, NumPy and PyTorch (CPU and, if available, CUDA);
      * enables ``torch.backends.cudnn.benchmark``.

    Args:
        args: attribute-style settings container (e.g. EasyDict) providing at
            least ``config, image_folder, deg, eta, etaB, sigma_y, T_sampling,
            travel_length, travel_repeat, verbose, seed``.

    Returns:
        ``(args, config)``: the same (mutated) ``args`` object and the YAML
        config converted to nested ``argparse.Namespace`` objects.

    Raises:
        ValueError: if ``config['data']`` has neither ``dataset`` nor ``name``,
            or ``args.verbose`` is not a logging level name.
    """
    # parse config file
    with open(os.path.join("configs", args.config), "r") as f:
        config = yaml.safe_load(f)
    new_config = dict2namespace(config)

    if 'dataset' in config['data']:
        dataset_name = config['data']['dataset']
    elif 'name' in config['data']:
        dataset_name = config['data']['name']
    else:
        # Previously this fell through and crashed later with UnboundLocalError.
        raise ValueError(
            "config 'data' section must define 'dataset' or 'name' ({})".format(args.config)
        )

    # The sampler reads config.time_travel; make the notebook's optional
    # T_sampling / travel_* settings take effect there as well.
    if hasattr(new_config, 'time_travel'):
        new_config.time_travel.T_sampling = args.T_sampling
        new_config.time_travel.travel_length = args.travel_length
        new_config.time_travel.travel_repeat = args.travel_repeat

    time_info = 'Tsample{}_len{}_repeat{}'.format(args.T_sampling, args.travel_length, args.travel_repeat)

    args.new_image_folder = os.path.join(args.image_folder, f'{dataset_name}_{args.deg}/{time_info}/eta{args.eta}_etaB{args.etaB}_sigmay{args.sigma_y}')
    new_config.deg = args.deg

    level = getattr(logging, args.verbose.upper(), None)
    if not isinstance(level, int):
        raise ValueError("level {} not supported".format(args.verbose))

    logger = logging.getLogger()
    # Attach our handler only once: re-running this function in the same
    # kernel would otherwise stack handlers and print every record repeatedly.
    handler_name = 'ddnm_notebook'
    if not any(h.get_name() == handler_name for h in logger.handlers):
        handler1 = logging.StreamHandler()
        handler1.set_name(handler_name)
        formatter = logging.Formatter(
            "%(levelname)s - %(filename)s - %(asctime)s - %(message)s"
        )
        handler1.setFormatter(formatter)
        logger.addHandler(handler1)
    logger.setLevel(level)

    os.makedirs(args.new_image_folder, exist_ok=True)

    # add device
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    logging.info("Using device: {}".format(device))
    new_config.device = device
    args.device = device

    # set random seed (Python's `random` is also used by the DataLoader worker init)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    torch.backends.cudnn.benchmark = True

    return args, new_config


In [None]:
def logit_transform(image, lam=1e-6):
    """Map values in [0, 1] to logit space.

    Values are first squeezed into ``[lam, 1 - lam]`` so the logit stays
    finite at exactly 0 and 1.
    """
    squeezed = image * (1 - 2 * lam) + lam
    return torch.log(squeezed) - torch.log1p(-squeezed)

def compute_alpha(beta, t):
    """Gather cumulative products ``alpha_bar_t = prod_{s<=t} (1 - beta_s)``.

    A zero beta is prepended so that index ``t = -1`` yields ``alpha_bar = 1``.
    ``t`` is an integer index tensor. The result is reshaped to
    ``(len(t), 1, 1, 1)`` so it broadcasts against image batches.
    """
    padded_beta = torch.cat([torch.zeros(1, device=beta.device), beta], dim=0)
    alpha_bar = torch.cumprod(1 - padded_beta, dim=0)
    picked = alpha_bar.index_select(0, t + 1)
    return picked.view(-1, 1, 1, 1)

def data_transform(config, X):
    """Map images from [0, 1] into the model's input space, as set by ``config.data``.

    Steps, in order:
      1. optional uniform and/or Gaussian dequantization noise;
      2. rescale to [-1, 1] (``rescaled``) or, otherwise, logit transform;
      3. subtract ``config.image_mean`` when the config carries one.

    NOTE(review): the later cell ``from datasets import ..., data_transform,
    inverse_data_transform`` rebinds this name, so the sampler uses the DDNM
    package version once that cell has run.
    """
    data_cfg = config.data
    if data_cfg.uniform_dequantization:
        X = X / 256.0 * 255.0 + torch.rand_like(X) / 256.0
    if data_cfg.gaussian_dequantization:
        X = X + torch.randn_like(X) * 0.01

    if data_cfg.rescaled:
        X = 2 * X - 1.0
    elif data_cfg.logit_transform:
        X = logit_transform(X)

    if not hasattr(config, "image_mean"):
        return X
    return X - config.image_mean.to(X.device)[None, ...]

def inverse_data_transform(x):
    """Map values from [-1, 1] back to [0, 1], clamping anything outside.

    NOTE(review): the later ``from datasets import ...`` cell rebinds this
    name to the DDNM package version; the sampler here calls
    ``inverse_data_transform_config`` instead.
    """
    shifted = (x + 1.0) / 2.0
    return shifted.clamp(0.0, 1.0)

def inverse_data_transform_config(config, X):
    """Undo ``data_transform`` for display/saving: model space -> [0, 1].

    Adds back ``config.image_mean`` if present. It then applies a sigmoid when
    ``config.data.logit_transform`` is set, or otherwise maps [-1, 1] -> [0, 1]
    when ``config.data.rescaled`` is set. The result is clamped to [0, 1].
    """
    restored = (X + config.image_mean.to(X.device)[None, ...]) if hasattr(config, "image_mean") else X

    data_cfg = config.data
    if data_cfg.logit_transform:
        restored = torch.sigmoid(restored)
    elif data_cfg.rescaled:
        restored = (restored + 1.0) / 2.0
    return restored.clamp(0.0, 1.0)

def _check_times(times, t_0, T_sampling):
    # Check end
    assert times[0] > times[1], (times[0], times[1])

    # Check beginning
    assert times[-1] == -1, times[-1]

    # Steplength = 1
    for t_last, t_cur in zip(times[:-1], times[1:]):
        assert abs(t_last - t_cur) == 1, (t_last, t_cur)

    # Value range
    for t in times:
        assert t >= t_0, (t, t_0)
        assert t <= T_sampling, (t, T_sampling)

In [None]:
def get_gaussian_noisy_img(img, noise_level):
    """Return ``img`` plus i.i.d. Gaussian noise with standard deviation ``noise_level``.

    The noise is drawn on ``img``'s own device (``randn_like``). The previous
    ``.cuda()`` call crashed on CPU-only runtimes and mixed devices when
    ``img`` lived on the CPU; on CUDA inputs the result is unchanged.
    """
    return img + torch.randn_like(img) * noise_level

def MeanUpsample(x, scale):
    """Nearest-neighbour upsampling: copy each pixel into a ``scale x scale`` block.

    Args:
        x: tensor shaped ``(n, c, h, w)``.
        scale: integer upsampling factor.

    Returns:
        Tensor shaped ``(n, c, h * scale, w * scale)``. The result comes from
        adding a default-dtype zero tensor, so dtype promotion follows that
        addition.
    """
    n, c, h, w = x.shape
    tiles = x.view(n, c, h, 1, w, 1) + torch.zeros(n, c, h, scale, w, scale, device=x.device)
    return tiles.view(n, c, h * scale, w * scale)

def color2gray(x):
    """Average the first three channels with weight 1/3 and replicate to 3 channels.

    Args:
        x: image batch shaped ``(N, C, H, W)`` with ``C >= 3``; only channels
            0-2 are used.

    Returns:
        Tensor shaped ``(N, 3, H, W)`` whose three channels are identical.
        The previous version applied ``repeat(1, 3, 1, 1)`` to the channel-less
        ``(N, H, W)`` tensor, which gives ``(1, 3N, H, W)`` whenever ``N > 1``.
        For ``N == 1`` the output is unchanged.
    """
    coef = 1 / 3
    gray = x[:, 0, :, :] * coef + x[:, 1, :, :] * coef + x[:, 2, :, :] * coef
    # Restore the channel axis before replicating it.
    return gray.unsqueeze(1).repeat(1, 3, 1, 1)

def gray2color(x):
    """Pseudo-inverse of the 1/3-weighted grayscale operator.

    Takes channel 0 of ``x`` and spreads it back to three channels using
    ``A^T / (A A^T)`` with ``A = [1/3, 1/3, 1/3]``. Returns a tensor stacked
    along dim 1, i.e. ``(N, 3, H, W)`` for an ``(N, C, H, W)`` input.
    """
    gray = x[:, 0, :, :]
    weight = 1 / 3
    norm = weight ** 2 + weight ** 2 + weight ** 2
    channel = gray * weight / norm
    return torch.stack([channel, channel, channel], dim=1)

# 1. Define Diffusion Process

## 1-1. Noise scheduling function
* Linear: the DDPM schedule. `get_beta_schedule` below also implements the quad, const, jsd and sigmoid schedules.

In [None]:
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """Return the per-step noise variances ``beta_1 .. beta_T`` as a float64 array.

    Supported schedules:
      * ``"linear"``  - evenly spaced from ``beta_start`` to ``beta_end`` (DDPM);
      * ``"quad"``    - evenly spaced in ``sqrt(beta)``, then squared;
      * ``"const"``   - ``beta_end`` at every step;
      * ``"jsd"``     - ``1/T, 1/(T-1), ..., 1``;
      * ``"sigmoid"`` - a sigmoid over [-6, 6] rescaled to ``[beta_start, beta_end]``.

    Raises:
        NotImplementedError: for any other schedule name.
    """
    steps = num_diffusion_timesteps
    schedules = {
        "quad": lambda: np.linspace(beta_start ** 0.5, beta_end ** 0.5, steps, dtype=np.float64) ** 2,
        "linear": lambda: np.linspace(beta_start, beta_end, steps, dtype=np.float64),
        "const": lambda: beta_end * np.ones(steps, dtype=np.float64),
        "jsd": lambda: 1.0 / np.linspace(steps, 1, steps, dtype=np.float64),
        "sigmoid": lambda: (1 / (np.exp(-np.linspace(-6, 6, steps)) + 1)) * (beta_end - beta_start) + beta_start,
    }

    build = schedules.get(beta_schedule)
    if build is None:
        raise NotImplementedError(beta_schedule)

    betas = build()
    assert betas.shape == (num_diffusion_timesteps,)
    return betas


## 1-2. Class Diffusion
* `sample`: downloads the pretrained weights and runs the DDNM algorithm.

In [None]:
import torch.utils.data as data
import tqdm
from datasets import get_dataset, data_transform, inverse_data_transform
from guided_diffusion.models import Model
from guided_diffusion.script_util import create_model
from functions.ckpt_util import get_ckpt_path, download

class Diffusion(object):
    """DDNM / DDNM+ restoration driver built on a pretrained diffusion model.

    ``__init__`` builds the beta / cumulative-alpha schedule from
    ``config.diffusion``. :meth:`sample` loads the pretrained network, and
    :meth:`svd_based_ddnm_plus` restores the selected test images.
    """

    def __init__(self, args, config, device=None):
        """Store the run settings and precompute the noise schedule.

        Args:
            args: run arguments produced by ``parse_args_and_config``.
            config: nested namespace built from the YAML config; reads the
                ``model.var_type`` and ``diffusion.*`` fields here.
            device: torch device; defaults to CUDA when available, else CPU.
        """
        self.args = args
        self.config = config
        if device is None:
            device = (torch.device("cuda") if torch.cuda.is_available()
                      else torch.device("cpu"))
        self.device = device

        self.model_var_type = config.model.var_type
        betas = get_beta_schedule(
            beta_schedule=config.diffusion.beta_schedule,
            beta_start=config.diffusion.beta_start,
            beta_end=config.diffusion.beta_end,
            num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
        )
        betas = self.betas = torch.from_numpy(betas).float().to(self.device)
        self.num_timesteps = betas.shape[0]

        alphas = 1.0 - betas
        alphas_cumprod = alphas.cumprod(dim=0)
        # alpha_bar_{t-1}, with alpha_bar = 1 prepended for the first step.
        alphas_cumprod_prev = torch.cat(
            [torch.ones(1).to(device), alphas_cumprod[:-1]], dim=0
        )
        self.alphas_cumprod_prev = alphas_cumprod_prev
        # DDPM posterior variance beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        if self.model_var_type == "fixedlarge":
            self.logvar = betas.log()
        elif self.model_var_type == "fixedsmall":
            self.logvar = posterior_variance.clamp(min=1e-20).log()

    def sample(self, A_funcs):
        """Load the pretrained model and run SVD-based DDNM / DDNM+ with ``A_funcs``.

        Only ``config.model.type == 'openai'`` is supported. That model is the
        guided-diffusion 256x256 unconditional checkpoint, downloaded to
        ``<args.exp>/logs/imagenet/`` on first use.

        Raises:
            ValueError: for any other ``config.model.type``. Previously this
                case crashed later with an UnboundLocalError on ``model``.
        """
        cls_fn = None

        if self.config.model.type != 'openai':
            raise ValueError(
                f"Unsupported model type {self.config.model.type!r}; only 'openai' is supported."
            )

        # Download Pretrained model checkpoint
        config_dict = vars(self.config.model)
        model = create_model(**config_dict)
        if self.config.model.use_fp16:
            model.convert_to_fp16()

        ckpt = os.path.join(self.args.exp, "logs/imagenet/256x256_diffusion_uncond.pt")
        if not os.path.exists(ckpt):
            download(
                'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt',
                ckpt)

        model.load_state_dict(torch.load(ckpt, map_location=self.device))
        model.to(self.device)
        model.eval()
        model = torch.nn.DataParallel(model)

        print('Run SVD-based DDNM.',
              f'{self.config.time_travel.T_sampling} sampling steps.',
              f'travel_length = {self.config.time_travel.travel_length},',
              f'travel_repeat = {self.config.time_travel.travel_repeat}.',
              f'Task: {self.args.deg}.'
             )
        self.svd_based_ddnm_plus(model, cls_fn, A_funcs)

    def svd_based_ddnm_plus(self, model, cls_fn, A_funcs):
        """Restore every image of the selected test subset and report PSNR.

        For each batch:
          1. degrade the ground truth with ``A_funcs.A``, optionally adding noise;
          2. save the degraded view and the ground truth;
          3. run DDNM (``sigma_y == 0``) or DDNM+ (noisy case);
          4. save the reconstruction and accumulate its PSNR against the
             ground truth.

        Images are written to ``args.new_image_folder``.

        Args:
            model: noise-prediction network, called as ``model(x_t, t)``.
            cls_fn: optional classifier-guidance function (``None`` here).
            A_funcs: degradation operator exposing ``A`` and ``A_pinv``.
        """
        args, config = self.args, self.config

        dataset, test_dataset = get_dataset(args, config)

        if args.subset_start >= 0 and args.subset_end > 0:
            assert args.subset_end > args.subset_start
            test_dataset = torch.utils.data.Subset(test_dataset, range(args.subset_start, args.subset_end))
        else:
            args.subset_start = 0
            args.subset_end = len(test_dataset)

        print(f'Dataset has size {len(test_dataset)}')

        def seed_worker(worker_id):
            worker_seed = args.seed % 2 ** 32
            np.random.seed(worker_seed)
            random.seed(worker_seed)

        g = torch.Generator()
        g.manual_seed(args.seed)
        val_loader = data.DataLoader(
            test_dataset,
            batch_size=config.sampling.batch_size,
            shuffle=True,
            num_workers=config.data.num_workers,
            worker_init_fn=seed_worker,
            generator=g,
        )

        deg = args.deg
        # sigma_y is specified for images in [0, 1]; doubled to account for
        # scaling to [-1, 1]. Kept local: the old in-place update of
        # args.sigma_y compounded on every re-run with the same `args`.
        sigma_y = 2 * args.sigma_y

        print(f'Start from {args.subset_start}')
        idx_init = args.subset_start
        idx_so_far = args.subset_start
        avg_psnr = 0.0
        pbar = tqdm.tqdm(val_loader)
        for x_orig, classes in pbar:
            x_orig = x_orig.to(self.device)
            x_orig = data_transform(self.config, x_orig)

            y = A_funcs.A(x_orig)

            b, hwc = y.size()
            # Temporarily view y as an image: 1 channel for colorization,
            # 3 otherwise. Inpainting / 'cs' measurements stay flat.
            if 'color' in deg:
                hw = hwc / 1
                h = w = int(hw ** 0.5)
                y = y.reshape((b, 1, h, w))
            elif 'inp' in deg or 'cs' in deg:
                pass
            else:
                hw = hwc / 3
                h = w = int(hw ** 0.5)
                y = y.reshape((b, 3, h, w))

            if self.args.add_noise:  # for denoising test
                y = get_gaussian_noisy_img(y, sigma_y)

            y = y.reshape((b, hwc))

            # A^+ y mapped back to image space, saved as the "degradation" view.
            Apy = A_funcs.A_pinv(y).view(y.shape[0], config.data.channels,
                                         self.config.data.image_size,
                                         self.config.data.image_size)

            if deg[:6] == 'deblur':
                Apy = y.view(y.shape[0], config.data.channels,
                             self.config.data.image_size, self.config.data.image_size)
            elif deg == 'colorization':
                Apy = y.view(y.shape[0], 1, self.config.data.image_size,
                             self.config.data.image_size).repeat(1, 3, 1, 1)
            elif deg == 'inpainting':
                # Presumably A^+ A(1) is 0 on masked pixels, which become -1
                # (black after the inverse transform) - TODO confirm.
                Apy += A_funcs.A_pinv(A_funcs.A(torch.ones_like(Apy))).reshape(*Apy.shape) - 1

            for i in range(len(Apy)):
                tvu.save_image(
                    inverse_data_transform_config(config, Apy[i]),
                    os.path.join(self.args.new_image_folder, f"degradation-{idx_so_far + i}.png")
                )
                tvu.save_image(
                    inverse_data_transform_config(config, x_orig[i]),
                    os.path.join(self.args.new_image_folder, f"gt-{idx_so_far + i}.png")
                )

            # Start DDIM
            x = torch.randn(
                y.shape[0],
                config.data.channels,
                config.data.image_size,
                config.data.image_size,
                device=self.device,
            )

            with torch.no_grad():
                if sigma_y == 0.:  # noise-free case, turn to ddnm
                    x, _ = ddnm_diffusion(x, model, self.betas, self.args.eta, A_funcs, y, args.device, cls_fn=cls_fn, classes=classes, config=config)
                else:  # noisy case, turn to ddnm+
                    from functions.svd_ddnm import ddnm_plus_diffusion
                    x, _ = ddnm_plus_diffusion(x, model, self.betas, self.args.eta, A_funcs, y, sigma_y, args.device, cls_fn=cls_fn, classes=classes, config=config)

            x = [inverse_data_transform_config(config, xi) for xi in x]

            for j in range(x[0].size(0)):
                tvu.save_image(
                    x[0][j], os.path.join(self.args.new_image_folder, f"recon-{idx_so_far + j}_{0}.png")
                )
                orig = inverse_data_transform_config(config, x_orig[j])
                mse = torch.mean((x[0][j].to(self.device) - orig) ** 2)
                psnr = 10 * torch.log10(1 / mse)
                avg_psnr += psnr

            idx_so_far += y.shape[0]

            pbar.set_description("PSNR: %.2f" % (avg_psnr / (idx_so_far - idx_init)))

        n_samples = idx_so_far - idx_init
        if n_samples == 0:
            # Avoid ZeroDivisionError when the selected subset is empty.
            print("No samples were processed.")
            return
        avg_psnr = avg_psnr / n_samples
        print("Total Average PSNR: %.2f" % avg_psnr)
        print("Number of samples: %d" % n_samples)


In [None]:
def get_schedule_jump(T_sampling, travel_length, travel_repeat):
    """Build the (time-travel) list of sampling indices, ending with -1.

    The walk goes down one index at a time from ``T_sampling - 1``. At every
    multiple of ``travel_length`` below ``T_sampling - travel_length`` it
    jumps back up ``travel_length`` steps, ``travel_repeat - 1`` times. The
    result is validated with ``_check_times``.
    """
    remaining_jumps = {
        start: travel_repeat - 1
        for start in range(0, T_sampling - travel_length, travel_length)
    }

    schedule = []
    cur = T_sampling
    while cur >= 1:
        cur -= 1
        schedule.append(cur)

        if remaining_jumps.get(cur, 0) > 0:
            remaining_jumps[cur] -= 1
            # Time travel: climb back up one index at a time.
            schedule.extend(range(cur + 1, cur + travel_length + 1))
            cur += travel_length

    schedule.append(-1)

    _check_times(schedule, -1, T_sampling)

    return schedule

# 2. DDNM

In [None]:
def ddnm_diffusion(x, model, b, eta, A_funcs, y, device, cls_fn=None, classes=None, config=None):
    """Noise-free DDNM sampling with optional time travel.

    Normal step (``j < i``): the model's clean-image estimate ``x0_t`` is
    corrected with the range-null-space rule
    ``x0_hat = x0_t - A^+ (A x0_t - y)``. The next state then follows a
    DDIM-style update whose fresh-noise weight is controlled by ``eta``.
    Time-travel step (``j >= i``): the last (uncorrected) ``x0_t`` is
    re-noised to the later timestep ``j``.

    Only the latest state and x0 prediction are kept. The previous version
    stored every intermediate tensor on the CPU and copied the state
    GPU->CPU->GPU each step, although only the last ones are returned.
    Outputs and random-number usage are unchanged.

    Args:
        x: initial Gaussian noise shaped like the image batch.
        model: noise predictor called as ``model(x_t, t)``. If it returns 6
            channels, only the first 3 are used.
        b: betas tensor indexed by ``compute_alpha``.
        eta: weight of fresh noise in each update (0 = deterministic).
        A_funcs: degradation operator exposing ``A`` / ``A_pinv`` on
            flattened images.
        y: measurement; flattened to ``(N, -1)`` before use.
        device: device the model and updates run on.
        cls_fn, classes: unused; kept for signature compatibility.
        config: needs ``diffusion.num_diffusion_timesteps`` and
            ``time_travel.{T_sampling, travel_length, travel_repeat}``.

    Returns:
        ``([x_final], [x0_last])``: the final sample and the last x0
        prediction, both on the CPU.
    """
    with torch.no_grad():
        # setup iteration variables
        # Map schedule indices onto the model's full timestep range.
        skip = config.diffusion.num_diffusion_timesteps // config.time_travel.T_sampling
        n = x.size(0)
        xt = x           # current state x_t
        x0_last = None   # latest uncorrected x0 prediction (reused by time travel)

        # generate time schedule
        times = get_schedule_jump(config.time_travel.T_sampling,
                                  config.time_travel.travel_length,
                                  config.time_travel.travel_repeat,
                                  )
        time_pairs = list(zip(times[:-1], times[1:]))

        # reverse diffusion sampling
        for i, j in tqdm.tqdm(time_pairs):
            i, j = i * skip, j * skip
            if j < 0:
                j = -1

            if j < i:  # normal sampling
                t = (torch.ones(n) * i).to(x.device)
                next_t = (torch.ones(n) * j).to(x.device)
                at = compute_alpha(b, t.long())
                at_next = compute_alpha(b, next_t.long())

                xt = xt.to(device)
                et = model(xt, t)

                # Keep only the epsilon part of a 6-channel (eps + variance) output.
                if et.size(1) == 6:
                    et = et[:, :3]

                # Predicted clean image from the noise estimate.
                x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()

                # Range-null-space correction: x0_hat = x0_t - A^+(A x0_t - y).
                x0_t_hat = x0_t - A_funcs.A_pinv(
                    A_funcs.A(x0_t.reshape(x0_t.size(0), -1)) - y.reshape(y.size(0), -1)
                ).reshape(*x0_t.size())

                c1 = (1 - at_next).sqrt() * eta
                c2 = (1 - at_next).sqrt() * ((1 - eta ** 2) ** 0.5)
                xt = at_next.sqrt() * x0_t_hat + c1 * torch.randn_like(x0_t) + c2 * et

                x0_last = x0_t

            else:  # time-travel back
                next_t = (torch.ones(n) * j).to(x.device)
                at_next = compute_alpha(b, next_t.long())

                x0_t = x0_last.to(device)

                xt = at_next.sqrt() * x0_t + torch.randn_like(x0_t) * (1 - at_next).sqrt()

    return [xt.to('cpu')], [x0_last.to('cpu')]

## Get Degradation Operators

In [None]:
def get_degradation_matrix(deg, device, deg_scale=None):
    """Build the SVD-based degradation operator ``A`` for the chosen task.

    All operators are built for 3-channel 256x256 images.

    Args:
        deg: one of ``'inpainting'``, ``'colorization'``,
            ``'sr_averagepooling'``, ``'sr_bicubic'``, ``'deblur_gauss'``.
        device: torch device for the operator's tensors.
        deg_scale: super-resolution factor, used only by the two ``sr_*``
            tasks. When omitted it falls back to the notebook-global
            ``args.deg_scale``, which this function previously always read
            while ignoring the parameter, and then to 4.

    Returns:
        The operator object from ``functions.svd_operators``.

    Raises:
        ValueError: for an unsupported ``deg``.
    """
    if deg_scale is None:
        # Backward compatibility for callers that never passed deg_scale.
        deg_scale = getattr(globals().get("args"), "deg_scale", 4)

    A_funcs = None

    if deg == 'inpainting':
        from functions.svd_operators import Inpainting
        loaded = np.load("exp/inp_masks/mask.npy")
        mask = torch.from_numpy(loaded).to(device).reshape(-1)
        # Indices of masked-out pixels, expanded to the three interleaved
        # colour entries (idx*3 + c). Assumes mask == 0 marks missing pixels - TODO confirm.
        missing_r = torch.nonzero(mask == 0).long().reshape(-1) * 3
        missing_g = missing_r + 1
        missing_b = missing_g + 1
        missing = torch.cat([missing_r, missing_g, missing_b], dim=0)
        A_funcs = Inpainting(3, 256, missing, device)
    elif deg == 'colorization':
        from functions.svd_operators import Colorization
        A_funcs = Colorization(256, device)
    elif deg == 'sr_averagepooling':
        blur_by = int(deg_scale)
        from functions.svd_operators import SuperResolution
        A_funcs = SuperResolution(3, 256, blur_by, device)
    elif deg == 'sr_bicubic':
        factor = int(deg_scale)
        from functions.svd_operators import SRConv

        def bicubic_kernel(x, a=-0.5):
            # Keys cubic-convolution kernel with a = -0.5.
            if abs(x) <= 1:
                return (a + 2) * abs(x) ** 3 - (a + 3) * abs(x) ** 2 + 1
            elif 1 < abs(x) and abs(x) < 2:
                return a * abs(x) ** 3 - 5 * a * abs(x) ** 2 + 8 * a * abs(x) - 4 * a
            else:
                return 0

        k = np.zeros((factor * 4))
        for i in range(factor * 4):
            x = (1 / factor) * (i - np.floor(factor * 4 / 2) + 0.5)
            k[i] = bicubic_kernel(x)
        k = k / np.sum(k)
        kernel = torch.from_numpy(k).float().to(device)
        A_funcs = SRConv(kernel / kernel.sum(),
                         3, 256, device, stride=factor)
    elif deg == 'deblur_gauss':
        from functions.svd_operators import Deblurring
        sigma = 10
        pdf = lambda x: torch.exp(torch.Tensor([-0.5 * (x / sigma) ** 2]))
        kernel = torch.Tensor([pdf(-2), pdf(-1), pdf(0), pdf(1), pdf(2)]).to(device)
        A_funcs = Deblurring(kernel / kernel.sum(), 3, 256, device)
    else:
        raise ValueError("degradation type not supported")

    return A_funcs

# 3. Main

In [None]:
def colab_main(args):
    """Notebook entry point: configure the run, build the operator and restore images.

    Args:
        args: EasyDict of run settings (see the configuration cells).

    Returns:
        0 on success, 1 if building the sampler or sampling failed. In that
        case the traceback is logged rather than raised, so the notebook
        keeps running.
    """
    input_args, config = parse_args_and_config(args)

    try:
        runner = Diffusion(input_args, config)
        A_funcs = get_degradation_matrix(
            input_args.deg,
            input_args.device,
            # Pass the scale explicitly instead of relying on a global `args`.
            deg_scale=getattr(input_args, 'deg_scale', None),
        )
        runner.sample(A_funcs)
    except Exception:
        logging.error(traceback.format_exc())
        return 1

    return 0

## Select Degradation types and Images

In [None]:
from easydict import EasyDict

args = EasyDict(
    # --- Paths, logging and reproducibility ---
    image_folder='../results',   # results root; a per-run subfolder is created under it
    config='imagenet_256.yml',   # YAML file read from DDNM/configs/
    path_y='imagenet',           # not read by this notebook's own code; presumably used by DDNM's get_dataset - TODO confirm
    simplified=False,            # not read by this notebook's own code
    seed=1234,                   # seeds torch / numpy and the DataLoader workers
    exp='exp',                   # the pretrained checkpoint is stored under <exp>/logs/imagenet/
    verbose='info',              # logging level name
    # --- Measurement noise and sampling ---
    sigma_y=0,                   # noise std for [0, 1] images; 0 -> DDNM, > 0 -> DDNM+
    eta=0.85,                    # weight of fresh noise in each sampling update
    cutoff=1000,                 # not read by this notebook's own code
    etaB=1,                      # only appears in the output folder name here
    noise_type='gaussian',       # not read by this notebook's own code
    add_noise=False,             # add Gaussian noise (std sigma_y) to the measurement y
    # (Optional) Set sampling timesteps and Time travel steps
    # Time travel was not covered in theory classes.
    timesteps=1000,
    T_sampling=100,
    travel_length=1,
    travel_repeat=1,
)


## Edit this block as you want

* Degradation types: sr_averagepooling, colorization, inpainting, sr_bicubic, deblur_gauss

In [None]:
# --- Degradation type (keep exactly one line uncommented) ---
args.deg = 'sr_averagepooling'
# args.deg = 'colorization'
# args.deg = 'inpainting'
# args.deg = 'sr_bicubic'
# args.deg = 'deblur_gauss'

# Super-resolution factor, read by sr_averagepooling / sr_bicubic (e.g. 4, 8 or 16).
args.deg_scale = 4

# --- Images to restore ---
# DDNM/exp/datasets/imagenet/imagenet holds 8 images. Colab stops if all 8
# are inferred at once, so run 1 or 2 at a time.
args.subset_start = 0  # start idx >= 0
args.subset_end = 2    # start idx < end idx <= 8

In [None]:
# Run restoration with the settings above; images are saved under args.image_folder.
colab_main(args=args)