# DDIM Inversion

Notebook written by **Jordan Lin** for a CS 188 project.

`main.ipynb` is an alternative to `main.py` that gives me more flexibility to experiment with my code.

In [None]:
# --- Run configuration ---
# train=True: train a new model from the config at `yaml_path`.
# train=False: load a finished run (config and checkpoints) from `log_path`.
train = False
yaml_path = "./configs/celeba.yml"  # only read when train is True
log_path = "./logs/run_230226_160439"  # only read when train is False

gpu_num = 0  # index of the CUDA device used when CUDA is available

## Preliminaries

In [None]:
# Automatically reload externally edited modules (e.g., files changed in PyCharm) before each cell runs
%load_ext autoreload
%autoreload 2

# Make the %tensorboard line magic available in this notebook
%load_ext tensorboard

In [None]:
from IPython.display import HTML
from functools import partial

import yaml

import numpy as np
import torch
from torch import nn, optim
from torch.utils import data
import torchvision.utils as vutils
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from tqdm import tqdm

In [None]:
from networks.unet import UNet
from runners.diffusion import Diffusion
from evaluation.fid import FID

import inversion.optimization as oinv
import inversion.learning as linv
import inversion.hybrid as hinv
import inversion.interpolation as iinv

import utilities.data as dutils
import utilities.math as mutils
import utilities.network as nutils
import utilities.runner as rutils
import utilities.utilities as utils

In [None]:
# Use the CUDA device chosen by gpu_num when CUDA is available; otherwise fall back to the CPU
device = torch.device(f"cuda:{gpu_num}") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def display_torch_image(image, norm=(0, 1)):
    """Render a channel-first image tensor in a new matplotlib figure.

    Args:
        image: A 3-D image tensor, or a 4-D batch of which only the first element is shown.
            The third-from-last axis is treated as the channel axis.
        norm: ``(low, high)`` pair mapped linearly onto ``[0, 1]`` before plotting; pass
            ``None`` to use the shown image's own minimum and maximum.
    """
    # A batched input is reduced to its first element
    picture = image[0] if len(image.shape) == 4 else image

    if norm is None:
        low, high = picture.min(), picture.max()
    else:
        low, high = norm[0], norm[1]
    picture = (picture - low) / (high - low)

    plt.figure(dpi=300)
    plt.axis("off")
    # imshow expects channels last, so move the channel axis to the end
    plt.imshow(picture.moveaxis(-3, -1).detach().cpu().numpy(), vmin=0, vmax=1)

In [None]:
def save_torch_image(image, path):
    """Write a channel-first image tensor to ``path`` as an 8-bit image file.

    Pixel values are multiplied by 255, clipped to ``[0, 255]`` and truncated to ``uint8``.
    """
    pixels = image.detach().cpu().moveaxis(0, -1).numpy()
    pixels = (pixels * 255).clip(0, 255).astype(np.uint8)
    Image.fromarray(pixels).save(path)

## Training & Loading

In [None]:
# Training reads the chosen YAML file; loading reuses the config saved inside the run directory
config = utils.get_yaml(path=yaml_path if train else f"{log_path}/config.yml")

In [None]:
# Show the loaded configuration for a quick sanity check
print(config)

In [None]:
# Build the diffusion runner from the config on the selected device and report its parameter count
diffusion = Diffusion(config, device=device)
print(f"Number of parameters: {diffusion.size()}")

In [None]:
# Serve TensorBoard inline for the logs/ directory on port 8008, with the image-sample retention limit raised to 10000
%tensorboard --logdir=logs --port=8008 --load_fast=false --samples_per_plugin images=10000

In [None]:
if not train:
    # Restore both the raw network weights and the EMA weights from the checkpoint files
    # named after config.training.num_i
    diffusion.load(path=log_path, name=f"network_{config.training.num_i}.pth", ema=False)
    diffusion.load(path=log_path, name=f"ema_{config.training.num_i}.pth", ema=True)
else:
    diffusion.train()

In [None]:
# Freeze the layers of both the raw and the EMA network to prevent OOM errors during naive inversion
diffusion.freeze(ema=False)
diffusion.freeze(ema=True)

## Sampling

Computing the FID score (or even just the Inception Score) takes a long time because generating the images is slow. As a result, the default number of sampled images is currently around $4096$. That is small, since FID tends to be lower (i.e., better) with more samples; the standard is $50000$.

In [None]:
# Sample a grid of 64 images (x="random" presumably means starting from random noise) and display it
sample_generations = diffusion.log_grid(x="random", batch_size=64)
display_torch_image(sample_generations)

In [None]:
# The train and validation splits share every dataset option except `split`
train_dataset = dutils.get_dataset(name=config.data.dataset, shape=config.data.shape,
                                   root=config.data.root, split="train",
                                   download=config.data.download)
valid_dataset = dutils.get_dataset(name=config.data.dataset, shape=config.data.shape,
                                   root=config.data.root, split="valid",
                                   download=config.data.download)

# Both loaders shuffle, so each fresh iterator over a loader yields a different random batch
train_loader = data.DataLoader(train_dataset, batch_size=config.training.batch_size,
                               shuffle=True, num_workers=config.data.num_workers)
valid_loader = data.DataLoader(valid_dataset, batch_size=config.training.batch_size,
                               shuffle=True, num_workers=config.data.num_workers)

In [None]:
# Set up the FID evaluator with the training and validation loaders as reference data (see evaluation.fid.FID)
fid = FID(train_loader, valid_loader, config, device=device)

In [None]:
# Score the model against training and validation data; slow, because it requires generating many samples
fid_train, fid_valid = fid(diffusion)

In [None]:
# Both scores are tensors; print each with 7 significant digits
print(f"FID   |   Training: {fid_train.cpu().numpy():.7}   Validation: {fid_valid.cpu().numpy():.7}")

## Optimization Inversion

In [None]:
def find_from_dataset(path, loader):
    """Return the dataset image closest (by per-pixel MSE) to the image file at ``path``.

    The file is read with PIL, converted to RGB, scaled to ``[0, 1]`` and compared against
    every image of every batch yielded by ``loader``.

    Args:
        path: Path of the reference image file.
        loader: Iterable of ``(images, labels)`` batches whose images have the same
            channel-first shape as the reference after conversion.

    Returns:
        The single best-matching image tensor (without a batch dimension), or ``None``
        if ``loader`` yields no batches.
    """
    with Image.open(path) as file:
        # convert("RGB") also accepts grayscale/palette files and drops any alpha channel
        pixels = np.asarray(file.convert("RGB"), dtype=np.float32) / 255
    target = torch.from_numpy(pixels).moveaxis(-1, 0).unsqueeze(dim=0)

    best_image = None
    best_error = float("inf")
    for images, _ in tqdm(loader):
        # Mean squared error of every image in the batch against the broadcast target
        errors = (target - images).square().mean(dim=(1, 2, 3))
        min_i = int(errors.argmin(dim=0))
        if errors[min_i] < best_error:
            best_error = errors[min_i].item()
            best_image = images[min_i]
    return best_image

In [None]:
# Set to True to recover the specific validation images saved under results/validation/ instead of drawing new random ones
load_existing = False

In [None]:
if not load_existing:
    # Each new iterator over the validation loader yields a fresh batch; keep its first image
    test_image_1 = next(iter(valid_loader))[0][0]
    test_image_2 = next(iter(valid_loader))[0][0]
else:
    # Find the dataset images that best match previously saved reference PNGs
    test_image_1 = find_from_dataset("results/validation/celeba_validation_1.png", valid_loader)
    test_image_2 = find_from_dataset("results/validation/celeba_validation_2.png", valid_loader)

display_torch_image(test_image_1)
display_torch_image(test_image_2)

In [None]:
# Gaussian starting latents shaped like their respective target images
z_1 = torch.randn(test_image_1.shape, device=device)
z_2 = torch.randn(test_image_2.shape, device=device)

### Interpolation

In [None]:
def save_row(images, path, indices=None):
    """Arrange selected images in a single-row grid, optionally save it, and display it.

    Args:
        images: Indexable sequence (list or tuple) of equally shaped channel-first image tensors.
        path: File the grid is written to via ``save_torch_image``; ``None`` skips saving.
        indices: Positions in ``images`` to include, in order. Defaults to all of them.
    """
    if indices is None:
        indices = range(len(images))
    selected = torch.stack([images[i] for i in indices], dim=0)
    # nrow equals the number of images, so everything lands in one row; 2-pixel padding filled with 1.0
    grid = vutils.make_grid(selected, nrow=selected.shape[0], padding=2, pad_value=1.0)
    if path is not None:
        save_torch_image(grid, path=path)
    display_torch_image(grid)

In [None]:
def save_torch_video(images, path, interval=50, scale=1, codec="h264"):
    """Encode a sequence of channel-first image tensors as a video file, then show the figure.

    Args:
        images: Non-empty sequence of equally sized channel-first image tensors, one per frame.
        path: Output file passed to ``Animation.save``.
        interval: Delay between frames in milliseconds.
        scale: Multiplier on the figure size (1 inch per 100 pixels when ``scale == 1``).
        codec: Video codec forwarded to ``Animation.save``.
    """
    # Matplotlib needs channels-last NumPy arrays living on the CPU
    frames_np = [image.detach().cpu().moveaxis(0, -1).numpy() for image in images]
    height, width = frames_np[0].shape[:2]

    figure = plt.figure()
    # Axes spanning the whole figure with axis decorations hidden, so only the image is drawn
    axes = plt.Axes(figure, [0.0, 0.0, 1.0, 1.0])
    axes.set_axis_off()
    figure.add_axes(axes)
    # set_size_inches takes (width, height)
    figure.set_size_inches(width / 100 * scale, height / 100 * scale)

    artists = [[axes.imshow(frame, animated=True, aspect=1)] for frame in frames_np]
    animation_ = animation.ArtistAnimation(figure, artists, interval=interval)
    animation_.save(path, codec=codec)
    plt.show()

In [None]:
# Optimization-based inversion of each test image: Adam (lr=0.02, num_i=300) under an SSIM criterion
proj_fn_1 = partial(oinv.gradient_inversion, target=test_image_1, diffusion=diffusion,
                    optimizer="adam", lr=0.02, num_i=300, criterion="ssim", show_progress=True)
proj_fn_2 = partial(oinv.gradient_inversion, target=test_image_2, diffusion=diffusion,
                    optimizer="adam", lr=0.02, num_i=300, criterion="ssim", show_progress=True)
# Pass clones so z_1/z_2 are not modified. With sequence=True the results are presumably
# per-iteration histories (later cells take [-1] as the final result); confirm in inversion.optimization
z_1_trained, x_1_reconstructed = proj_fn_1(z_1.clone(), sequence=True)
z_2_trained, x_2_reconstructed = proj_fn_2(z_2.clone(), sequence=True)

In [None]:
# Column 0 of each row is the target image, so the chosen reconstruction positions are shifted
# by one to skip over the prepended target.
indices = [0, *(step + 1 for step in (0, 1, 5, 10, 15, 25, 50, 100, 200, 300))]

save_row([test_image_1] + x_1_reconstructed, path=None, indices=indices)
save_row([test_image_2] + x_2_reconstructed, path=None, indices=indices)

In [None]:
# Save side-by-side comparisons of each target image and its final reconstruction
save_row((test_image_1, x_1_reconstructed[-1]), path="results/x_1_reconstructed.png")
save_row((test_image_2, x_2_reconstructed[-1]), path="results/x_2_reconstructed.png")

In [None]:
# Encode the full reconstruction sequence of test image 1 as a VP9 WebM video
save_torch_video(x_1_reconstructed, path="results/x_1_reconstructed.webm", scale=1, codec="vp9")

In [None]:
# Encode the full reconstruction sequence of test image 2 as a VP9 WebM video
save_torch_video(x_2_reconstructed, path="results/x_2_reconstructed.webm", scale=1, codec="vp9")

In [None]:
# Interpolate between the two final optimized latents. proj_fn_*=None presumably disables re-projection;
# num_t_steps and num_alphas semantics should be confirmed in inversion.interpolation
x_mixes = iinv.proj_interpolation(z_1_trained[-1].to(device), z_2_trained[-1].to(device),
                                  diffusion=diffusion, proj_fn_1=None, proj_fn_2=None,
                                  num_t_steps=10, num_alphas=150, show_progress=True)

In [None]:
# NOTE(review): despite the file name, this saves only the two interpolation endpoints
# (the final reconstructions), not the `x_mixes` frames; confirm this is intended
save_row((x_1_reconstructed[-1], x_2_reconstructed[-1]), path="results/x_mixes.png")

In [None]:
# Encode the interpolation frames as a VP9 WebM video
save_torch_video(x_mixes, path="results/x_mixes.webm", scale=1, codec="vp9")

## Miscellaneous

In [None]:
# Quick check of the LSUN church_outdoor training split, center-cropped to 256x256
# (`data` is also imported in the imports cell; re-importing it here is redundant but harmless)
from torch.utils import data
from torchvision import transforms, datasets

transform = transforms.Compose([transforms.CenterCrop((256, 256)), transforms.ToTensor()])

lsun_data = datasets.LSUN(root=config.data.root, classes=["church_outdoor_train"], transform=transform)
lsun_loader = data.DataLoader(lsun_data, batch_size=config.training.batch_size,
                              shuffle=True, num_workers=config.data.num_workers)

In [None]:
# Print the shape of one LSUN image and display it
lsun_image = next(iter(lsun_loader))[0][0]
print(lsun_image.shape)
display_torch_image(lsun_image)

In [None]:
# Count LSUN images by iterating the whole loader (also exercises loading every image);
# for a map-style dataset the result should equal len(lsun_data)
i = 0
for image, label in tqdm(iter(lsun_loader)):
    i += image.shape[0]
print(i)

In [None]:
# Same check for the miniplaces training split at 128x128
mini_data = dutils.get_dataset(name="miniplaces", shape=(128, 128),
                               root=config.data.root, split="train")
mini_loader = data.DataLoader(mini_data, batch_size=config.training.batch_size,
                              shuffle=True, num_workers=config.data.num_workers)

In [None]:
# Print the shape of one miniplaces image and display it
mini_image = next(iter(mini_loader))[0][0]
print(mini_image.shape)
display_torch_image(mini_image)

In [None]:
# Count miniplaces images by iterating the whole loader;
# for a map-style dataset the result should equal len(mini_data)
i = 0
for image, label in tqdm(iter(mini_loader)):
    i += image.shape[0]
print(i)