# DDIM Inversion

Notebook written by **Jordan Lin** for a CS 188 project.

`main.ipynb` is an interactive alternative to `main.py` that gives me more flexibility to experiment with the code.

In [None]:
train = False  # True: train a new model from `yaml_path`; False: load a trained run from `log_path`
yaml_path = "./configs/flowers102.yml"  # Config path, used only when train = True
log_path = "./logs/run_230224_152425"  # Run directory (config.yml + checkpoints), used only when train = False

gpu_num = 0  # Index of the CUDA device to use; ignored when CUDA is unavailable

## Preliminaries

In [None]:
# External files edited elsewhere (e.g., PyCharm) are reloaded in Jupyter Notebooks
%load_ext autoreload
%autoreload 2

%load_ext tensorboard

In [None]:
import yaml

import torch
from torch import nn, optim
from torch.utils import data
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
from networks.unet import UNet
from runners.diffusion import Diffusion

import utilities.data as dutils
import utilities.math as mutils
import utilities.network as nutils
import utilities.runner as rutils
import utilities.utilities as utils

In [None]:
# Run on the selected GPU when CUDA is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{gpu_num}")
else:
    device = torch.device("cpu")

In [None]:
def display_torch_image(image, norm=(0, 1)):
    """Show a torch image tensor in a new matplotlib figure.

    Args:
        image: Image tensor with the channel axis third from the end, e.g.
            (C, H, W). A 4-D batch shows only its first image.
        norm: (low, high) pair mapped linearly onto [0, 1] before display.
            Pass None to stretch the image's own min/max onto [0, 1].
    """
    image = image.detach()
    if image.dim() == 4:
        image = image[0]
    low, high = (image.min(), image.max()) if norm is None else norm
    span = high - low
    if span == 0:
        # A constant image would otherwise divide by zero and render as NaN.
        span = 1
    # Clamp explicitly: matplotlib ignores vmin/vmax for RGB data and would
    # clip out-of-range values itself with a warning.
    image = ((image - low) / span).clamp(0, 1)
    plt.figure()
    plt.axis("off")
    # Move the channel axis last, the layout imshow expects.
    plt.imshow(image.moveaxis(-3, -1).cpu().numpy(), vmin=0, vmax=1)

## Training & Loading

In [None]:
# Training reads the chosen YAML config; loading reuses the config saved with the run.
config = utils.get_yaml(path=yaml_path if train else f"{log_path}/config.yml")

In [None]:
# Build the diffusion runner on the selected device and report its size.
diffusion = Diffusion(config, device=device)
print("Number of parameters: {}".format(diffusion.size()))

In [None]:
# Serve TensorBoard for ./logs on port 8008, keeping up to 10000 image samples per tag
%tensorboard --logdir=logs --port=8008 --load_fast=false --samples_per_plugin images=10000

In [None]:
if not train:
    # Restore the raw network, then its EMA copy, from the checkpoints tagged
    # with config.training.num_i.
    for ckpt_prefix, use_ema in (("network", False), ("ema", True)):
        diffusion.load(path=f"{log_path}/{ckpt_prefix}_{config.training.num_i}.pth", ema=use_ema)
else:
    diffusion.train()

In [None]:
# Freeze the raw network and then its EMA copy to prevent OOM errors during naive inversion.
for use_ema in (False, True):
    diffusion.freeze(ema=use_ema)

## Naive Inversion

In [None]:
def naive_inversion(diffusion, target, z=None, lr=0.01, num_t_steps=10, num_i=100,
                    show_init=False, show_result=False, ema=True):
    """Optimize a latent ``z`` so that sampling from it reproduces ``target``.

    Runs Adam on ``z``, back-propagating an L1 reconstruction loss through
    ``diffusion.sample`` with ``num_t_steps`` steps.

    Args:
        diffusion: Diffusion runner exposing ``sample``.
        target: Image tensor to invert; moved to ``device``.
        z: Optional initial latent with ``target``'s shape. It is copied, so the
            caller's tensor is never modified. Drawn from N(0, I) when None.
        lr: Adam learning rate.
        num_t_steps: Number of sampling steps passed to ``diffusion.sample``.
        num_i: Number of optimization iterations.
        show_init: Display the sample produced by the initial latent.
        show_result: Display the sample produced by the optimized latent.
        ema: Forwarded to ``diffusion.sample`` for both the optimization and the
            previews, so the displayed images come from the network that was inverted.

    Returns:
        The optimized latent, detached from the autograd graph.
    """
    target = target.to(device)
    if z is None:
        z = torch.randn(*target.shape, device=device)
    # Optimize a private leaf copy: the caller's tensor is left untouched and
    # Adam always receives a leaf tensor on the right device.
    z = z.detach().clone().to(device).requires_grad_()

    def _show(latent):
        # Preview one sample without recording an autograd graph.
        with torch.no_grad():
            image = diffusion.sample(x=latent, sequence=False, ema=ema,
                                     num_t_steps=num_t_steps)[0]
        display_torch_image(image)

    if show_init:
        _show(z.detach())

    optimizer = optim.Adam([z], lr=lr, betas=(0.9, 0.999), eps=1e-8)

    progress = tqdm(range(num_i), position=0)
    for _ in progress:
        y = diffusion.sample(x=z, sequence=False, ema=ema, num_t_steps=num_t_steps)[0]
        loss = (target - y).abs().mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # set_description refreshes the bar itself.
        progress.set_description(f"Loss: {loss.item():.6f}")

    if show_result:
        _show(z.detach())

    # Detach so downstream sampling does not build an autograd graph through z.
    return z.detach()

In [None]:
# Held-out split used to pick images for the inversion experiments.
valid_dataset = dutils.get_dataset(
    name=config.data.dataset,
    shape=config.data.shape,
    root=config.data.root,
    split="valid",
    download=config.data.download,
)
valid_loader = data.DataLoader(
    valid_dataset,
    batch_size=config.training.batch_size,
    shuffle=True,
    num_workers=config.data.num_workers,
)

In [None]:
# Draw both test images from one shuffled pass. Two fresh iterators would each
# start their own worker processes and could return the same image twice.
# NOTE(review): assumes the validation loader yields at least two batches.
valid_iter = iter(valid_loader)
test_image_1 = next(valid_iter)[0][0]
test_image_2 = next(valid_iter)[0][0]
del valid_iter  # release the iterator (and its workers) once both images are fetched
display_torch_image(test_image_1)
display_torch_image(test_image_2)

In [None]:
def slerp(z1, z2, alpha, eps=1e-6):
    """Spherically interpolate between tensors ``z1`` and ``z2``.

    The angle is measured between the two tensors treated as flat vectors.
    ``alpha=0`` returns ``z1`` and ``alpha=1`` returns ``z2``.

    Args:
        z1, z2: Tensors of the same shape.
        alpha: Interpolation weight, normally in [0, 1].
        eps: When ``sin(theta)`` falls below this (nearly parallel or
            anti-parallel inputs) or either input has zero norm, the spherical
            formula is 0/0. In that case linear interpolation is returned.

    Returns:
        The interpolated tensor, with the same shape as the inputs.
    """
    denom = torch.norm(z1) * torch.norm(z2)
    if float(denom) == 0.0:
        return (1 - alpha) * z1 + alpha * z2
    # Rounding can push the cosine just outside [-1, 1], where acos returns NaN.
    cos_theta = (torch.sum(z1 * z2) / denom).clamp(-1.0, 1.0)
    theta = torch.acos(cos_theta)
    sin_theta = torch.sin(theta)
    if float(sin_theta.abs()) < eps:
        return (1 - alpha) * z1 + alpha * z2
    return torch.sin((1 - alpha) * theta) / sin_theta * z1 \
           + torch.sin(alpha * theta) / sin_theta * z2

In [None]:
# Shared settings for the experiments below:
# optimization iterations per inversion, and sampling steps per diffusion.sample call.
num_i, num_t_steps = 1000, 5

In [None]:
# One standard-normal latent per test image, drawn in order (z_1 first), matching each image's shape.
z_1, z_2 = (torch.randn(*image.shape, device=device) for image in (test_image_1, test_image_2))

### Interpolation

In [None]:
# Baseline: slerp between the two random latents and show a sample at each weight.
# Pure sampling, so no autograd graph is recorded through the sampling chain.
with torch.no_grad():
    for alpha in (0.0, 0.3, 0.5, 0.7, 1.0):
        z_mix = slerp(z_1, z_2, alpha=alpha)
        test_mix = diffusion.sample(x=z_mix, sequence=False, num_t_steps=num_t_steps)[0].detach()
        display_torch_image(test_mix)

### Projected Interpolation

In [None]:
# Fit a latent to each test image, starting from a copy of its random latent.
z_1_trained, z_2_trained = (
    naive_inversion(diffusion, image, z=z_init.clone(), num_i=num_i, num_t_steps=num_t_steps)
    for image, z_init in ((test_image_1, z_1), (test_image_2, z_2))
)

In [None]:
# Slerp between the inverted latents and show a sample at each weight.
# The inverted latents may still require grad. Without no_grad, every sample
# would record an autograd graph through the whole sampling chain (memory only,
# since the result is detached).
with torch.no_grad():
    for alpha in (0.0, 0.3, 0.5, 0.7, 1.0):
        z_mix = slerp(z_1_trained, z_2_trained, alpha=alpha)
        test_mix = diffusion.sample(x=z_mix, sequence=False, num_t_steps=num_t_steps)[0].detach()
        display_torch_image(test_mix)

## Miscellaneous

In [None]:
from torch.utils import data
from torchvision import transforms, datasets

# LSUN church (outdoor) training set: center-crop to 256x256, then convert to tensors.
transform = transforms.Compose([
    transforms.CenterCrop((256, 256)),
    transforms.ToTensor(),
])

lsun_data = datasets.LSUN(
    root=config.data.root,
    classes=["church_outdoor_train"],
    transform=transform,
)
lsun_loader = data.DataLoader(
    lsun_data,
    batch_size=config.training.batch_size,
    shuffle=True,
    num_workers=config.data.num_workers,
)

In [None]:
# Sanity check: take the first image of one shuffled batch, print its shape, and display it.
lsun_image = next(iter(lsun_loader))[0][0]  # [0] selects the batch's image tensor, [0] its first image
print(lsun_image.shape)
display_torch_image(lsun_image)

In [None]:
# Count the images with one full pass over the loader; every image is loaded on the way.
# Pass the loader itself (not iter(loader)) so tqdm can read its length and show a total/ETA.
i = 0
for images, _ in tqdm(lsun_loader):
    i += images.shape[0]
print(i)

In [None]:
# Miniplaces training split at 128x128.
mini_data = dutils.get_dataset(
    name="miniplaces",
    shape=(128, 128),
    root=config.data.root,
    split="train",
)
mini_loader = data.DataLoader(
    mini_data,
    batch_size=config.training.batch_size,
    shuffle=True,
    num_workers=config.data.num_workers,
)

In [None]:
# Sanity check: take the first image of one shuffled batch, print its shape, and display it.
mini_image = next(iter(mini_loader))[0][0]  # [0] selects the batch's image tensor, [0] its first image
print(mini_image.shape)
display_torch_image(mini_image)

In [None]:
# Count the images with one full pass over the loader; every image is loaded on the way.
# Pass the loader itself (not iter(loader)) so tqdm can read its length and show a total/ETA.
i = 0
for images, _ in tqdm(mini_loader):
    i += images.shape[0]
print(i)