# DDIM Inversion

Notebook written by **Jordan Lin** for a CS 188 project.

`main.ipynb` is an alternative to `main.py` that gives me more flexibility to experiment with the code.

In [None]:
# Run mode: True trains a new model from yaml_path; False loads the trained run stored in log_path.
train = False
yaml_path = "./configs/flowers102.yml"  # config file read only when train is True
log_path = "./logs/run_230226_031528"  # run directory read only when train is False

gpu_num = 0  # index of the CUDA device to use (unused when CUDA is unavailable)

## Preliminaries

In [None]:
# Reload project modules automatically, so files edited elsewhere (e.g., in PyCharm)
# are picked up in this notebook without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Enable the %tensorboard magic used to inspect the training logs.
%load_ext tensorboard

In [None]:
from IPython.display import HTML
from functools import partial

import yaml

import torch
from torch import nn, optim
from torch.utils import data
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from tqdm import tqdm

In [None]:
from networks.unet import UNet
from runners.diffusion import Diffusion

import inversions.optimization as oinv
import inversions.learning as linv
import inversions.hybrid as hinv
import inversions.interpolation as iinv

import utilities.data as dutils
import utilities.math as mutils
import utilities.network as nutils
import utilities.runner as rutils
import utilities.utilities as utils

In [None]:
# Use the selected GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{gpu_num}")
else:
    device = torch.device("cpu")

In [None]:
def display_torch_image(image, norm=(0, 1)):
    if len(image.shape) == 4:
        image = image[0]
    if norm is None:
        norm = (image.min(), image.max())
    image = (image - norm[0]) / (norm[1] - norm[0])
    plt.figure()
    plt.axis("off")
    plt.imshow(image.moveaxis(-3, -1).detach().cpu().numpy(), vmin=0, vmax=1)

## Training & Loading

In [None]:
# Training reads the chosen YAML file; loading reuses the config.yml stored in the run directory.
config_path = yaml_path if train else f"{log_path}/config.yml"
config = utils.get_yaml(path=config_path)

In [None]:
# Build the diffusion runner from the config on the selected device and report its size.
diffusion = Diffusion(config, device=device)
num_parameters = diffusion.size()
print(f"Number of parameters: {num_parameters}")

In [None]:
# Launch TensorBoard inline on port 8008 over the runs under ./logs; the
# --samples_per_plugin images=10000 flag sets how many logged image samples are kept for display.
%tensorboard --logdir=logs --port=8008 --load_fast=false --samples_per_plugin images=10000

In [None]:
if train:
    diffusion.train()
else:
    # Restore the raw network first, then its EMA copy, from the final training step.
    checkpoint_step = config.training.num_i
    for prefix, use_ema in (("network", False), ("ema", True)):
        diffusion.load(path=f"{log_path}/{prefix}_{checkpoint_step}.pth", ema=use_ema)

In [None]:
# Freeze both the raw and the EMA network layers to prevent OOM errors during naive inversion.
for use_ema in (False, True):
    diffusion.freeze(ema=use_ema)

## Optimization Inversion

In [None]:
# Validation split of the configured dataset, served by a shuffled loader.
data_cfg = config.data
valid_dataset = dutils.get_dataset(
    name=data_cfg.dataset,
    shape=data_cfg.shape,
    root=data_cfg.root,
    split="valid",
    download=data_cfg.download,
)
valid_loader = data.DataLoader(
    valid_dataset,
    batch_size=config.training.batch_size,
    shuffle=True,
    num_workers=data_cfg.num_workers,
)

In [None]:
# Draw both test images from ONE shuffled pass over the validation set. Consecutive
# batches of the same epoch hold different samples, so the two images are distinct.
# Only one loader iterator (and its worker processes, if any) is created.
# NOTE(review): assumes the validation loader yields at least two batches.
valid_iter = iter(valid_loader)
test_image_1 = next(valid_iter)[0][0]
test_image_2 = next(valid_iter)[0][0]
del valid_iter  # release the iterator (and any prefetching workers) right away
display_torch_image(test_image_1)
display_torch_image(test_image_2)

In [None]:
# Standard-normal latents shaped like each test image, created directly on the device (z_1 drawn first).
z_1, z_2 = (torch.randn(*image.shape, device=device) for image in (test_image_1, test_image_2))

### Interpolation

In [None]:
def save_torch_video(images, path, interval=50, scale=1):
    images = [image.moveaxis(0, -1).numpy() for image in images]
    
    figure = plt.figure()
    axes = plt.Axes(figure, [0.0, 0.0, 1.0, 1.0])
    axes.set_axis_off()
    figure.add_axes(axes)
    figure.set_size_inches(images[0].shape[0] / 100 * scale, images[0].shape[1] / 100 * scale)
    
    frames = []
    for image in images:
        frames.append([axes.imshow(image, animated=True, aspect=1)])
        
    animation_ = animation.ArtistAnimation(figure, frames, interval=50)
    animation_.save(path)
    plt.show()

In [None]:
# Baseline: interpolate between the two noise latents with both projection functions disabled (None).
interp_settings = dict(num_t_steps=10, num_alphas=100, show_progress=True)
x_mixes = iinv.proj_interpolation(z_1.clone(), z_2.clone(), diffusion=diffusion,
                                  proj_fn_1=None, proj_fn_2=None, **interp_settings)

In [None]:
# Save the unprojected interpolation, upscaled 5x.
unprojected_video_path = "results/test4.mp4"
save_torch_video(x_mixes, unprojected_video_path, scale=5)

In [None]:
# Each endpoint is projected with oinv.gradient_inversion towards its own test image,
# sharing one set of optimisation settings (PSNR criterion, 300 iterations, lr 0.02).
inversion_kwargs = dict(diffusion=diffusion, lr=0.02, num_i=300, criterion="psnr", show_progress=True)
proj_fn_1 = partial(oinv.gradient_inversion, target=test_image_1, **inversion_kwargs)
proj_fn_2 = partial(oinv.gradient_inversion, target=test_image_2, **inversion_kwargs)

x_mixes = iinv.proj_interpolation(
    z_1.clone(), z_2.clone(), diffusion=diffusion,
    proj_fn_1=proj_fn_1, proj_fn_2=proj_fn_2,
    num_t_steps=10, num_alphas=100, show_progress=True,
)

In [None]:
# Save the projected interpolation, upscaled 5x.
projected_video_path = "results/test5.mp4"
save_torch_video(x_mixes, projected_video_path, scale=5)

## Miscellaneous

In [None]:
from torch.utils import data
from torchvision import datasets, transforms

# Center-crop each LSUN church image to 256x256 and convert it to a tensor.
transform = transforms.Compose([
    transforms.CenterCrop((256, 256)),
    transforms.ToTensor(),
])

lsun_data = datasets.LSUN(
    root=config.data.root,
    classes=["church_outdoor_train"],
    transform=transform,
)
lsun_loader = data.DataLoader(
    lsun_data,
    batch_size=config.training.batch_size,
    shuffle=True,
    num_workers=config.data.num_workers,
)

In [None]:
# Peek at the first image of one shuffled LSUN batch.
lsun_batch_images = next(iter(lsun_loader))[0]
lsun_image = lsun_batch_images[0]
print(lsun_image.shape)
display_torch_image(lsun_image)

In [None]:
# Count the LSUN samples by iterating the full loader once (every image is loaded along the way).
i = 0
for batch in tqdm(iter(lsun_loader)):
    image, label = batch
    i += len(image)
print(i)

In [None]:
# MiniPlaces training split at 128x128, served by a shuffled loader using the config's batch size and worker count.
mini_data = dutils.get_dataset(
    name="miniplaces",
    shape=(128, 128),
    root=config.data.root,
    split="train",
)
mini_loader = data.DataLoader(
    mini_data,
    batch_size=config.training.batch_size,
    shuffle=True,
    num_workers=config.data.num_workers,
)

In [None]:
# Peek at the first image of one shuffled MiniPlaces batch.
mini_batch_images = next(iter(mini_loader))[0]
mini_image = mini_batch_images[0]
print(mini_image.shape)
display_torch_image(mini_image)

In [None]:
# Count the MiniPlaces samples by iterating the full loader once (every image is loaded along the way).
i = 0
for batch in tqdm(iter(mini_loader)):
    image, label = batch
    i += len(image)
print(i)