# DDIM Inversion

Notebook written by **Jordan Lin**.

`main.ipynb` is an alternative to `main.py` where I have more flexibility to experiment with my code.

In [1]:
# Notebook-wide switches and paths; edit these before running the cells below.
train = False  # True if we are training, False otherwise
yaml_path = "./configs/celeba.yml"  # Config path used when train = True
run_path = "run_230311_015229"  # Run directory name, used under ./logs and ./samples
log_path = f"./logs/{run_path}"  # Model load path used when train = False

gpu_num = 0  # Index of the CUDA device to use (for multiple-GPU machines)

## Preliminaries

In [2]:
# External files edited elsewhere (e.g., PyCharm) are reloaded in Jupyter Notebooks
%load_ext autoreload
%autoreload 2

# Makes the %tensorboard magic available for the "Training & Loading" section
%load_ext tensorboard

In [3]:
from IPython.display import HTML
from functools import partial

import os
import yaml

import numpy as np
import torch
from torch import nn, optim
from torch.utils import data
import torch.nn.functional as F
import torchvision.utils as vutils
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from tqdm.auto import tqdm

In [4]:
from networks.unet import UNet
from networks.resnet import ResNet
from runners.diffusion import Diffusion
from evaluation.fid import FID
from inversion.learning import NoiseEncoder

import inversion.optimization as oinv
import inversion.learning as linv
import inversion.hybrid as hinv
import inversion.interpolation as iinv

import editing.classification as eclass

import utilities.data as dutils
import utilities.math as mutils
import utilities.network as nutils
import utilities.runner as rutils
import utilities.utilities as utils

In [5]:
# Use the CUDA device selected by gpu_num when available; otherwise fall back to the CPU
device = torch.device(f"cuda:{gpu_num}" if torch.cuda.is_available() else "cpu")

In [6]:
def display_torch_image(image, norm=(0, 1), dpi=72):
    """Render an image tensor in a new matplotlib figure with the axes hidden.

    Args:
        image: Tensor whose third-from-last axis holds the channels. When a
            4-dimensional batch is given, only its first element is shown.
        norm: Pair ``(low, high)`` that is linearly mapped onto ``[0, 1]``
            before plotting. ``None`` uses the minimum and maximum of the
            displayed image instead.
        dpi: Resolution of the created figure.
    """
    frame = image[0] if image.ndim == 4 else image

    if norm is None:
        low, high = frame.min(), frame.max()
    else:
        low, high = norm[0], norm[1]
    scaled = (frame - low) / (high - low)

    # imshow expects the channel axis last, so move it from position -3 to the end
    pixels = scaled.moveaxis(-3, -1).detach().cpu().numpy()

    plt.figure(dpi=dpi)
    plt.axis("off")
    plt.imshow(pixels, vmin=0, vmax=1)

In [7]:
def save_torch_image(image, path):
    """Write a channels-first image tensor to disk as an 8-bit image.

    Values are scaled by 255, clipped to ``[0, 255]`` and truncated to
    ``uint8``, so the input is expected to lie in ``[0, 1]``.

    Args:
        image: Tensor with the channel axis first (it is moved to the end
            before saving). A single-channel image is saved as grayscale.
        path: Destination handed to ``PIL.Image.save``. When it is a string or
            path-like object, missing parent directories are created first.
    """
    pixels = image.detach().cpu().moveaxis(0, -1).numpy()
    pixels = np.clip(pixels * 255, 0, 255).astype(np.uint8)

    # PIL rejects (H, W, 1) arrays; a single channel must be passed as (H, W)
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        pixels = pixels[..., 0]

    # Avoid a FileNotFoundError when e.g. "results/" has not been created yet
    if isinstance(path, (str, os.PathLike)):
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

    Image.fromarray(pixels).save(path)

## Training & Loading

In [8]:
# Training reads the template config; evaluation reads the copy stored with the run
config = utils.get_yaml(path=yaml_path if train else f"{log_path}/config.yml")

In [9]:
# Show the full nested configuration that the rest of the notebook relies on
print(config)

Namespace(data=Namespace(dataset='celeba', root='~/.torch/datasets', shape=[3, 64, 64], shape_original=[3, 218, 178], num_train=162770, num_valid=19867, num_test=19962, random_flip=True, zero_center=True, clamp=True, flip_horizontal=0.5, flip_vertical=0.0, num_workers=4, download=True), network=Namespace(hidden_channels=32, num_blocks=2, channel_mults=[1, 2, 2, 2, 4], attention_sizes=[16], embed_channels=128, dropout=0.1, num_groups=8, ema=0.9995, do_conv_sample=True), diffusion=Namespace(beta_schedule='linear', beta_start=0.0001, beta_end=0.02, num_t=1000, num_t_steps=50, eta=0.0), training=Namespace(batch_size=64, log_batch_size=64, criterion='l1', num_i=72000, log_frequency=300, save_frequency=6000, tensorboard=True), evaluation=Namespace(batch_size=64, num_batches=64), optimizer=Namespace(name='adam', learning_rate=0.0002, weight_decay=0.0, beta_1=0.9, amsgrad=False, epsilon=1e-08, gradient_clip=1.0))


In [10]:
# Build the diffusion runner from the loaded config and report its parameter count
diffusion = Diffusion(config, device=device)
print(f"Number of parameters: {diffusion.size()}")

Number of parameters: 4935747


In [None]:
# Launch TensorBoard inline, reading from the logs directory on port 8008
%tensorboard --logdir=logs --port=8008 --load_fast=false --samples_per_plugin images=10000

In [11]:
if train:
    # Run training with the settings held by the diffusion runner
    diffusion.train()
else:
    # Load the checkpoints named after config.training.num_i from log_path:
    # first the regular network weights, then the EMA weights
    diffusion.load(path=log_path, name=f"network_{config.training.num_i}.pth", ema=False)
    diffusion.load(path=log_path, name=f"ema_{config.training.num_i}.pth", ema=True)

In [12]:
# Freeze both the regular network (ema=False) and the EMA network (ema=True)
diffusion.freeze(ema=False)  # Freeze model layers to prevent OOM error during naive inversion
diffusion.freeze(ema=True)

## Sampling

In [None]:
%%time

# Generate a grid from x="random" with 10 sampling steps and a batch of 64, then display it
sample_generations = diffusion.log_grid(x="random", num_t_steps=10, batch_size=64)
display_torch_image(sample_generations, dpi=300)

In [13]:
def _build_split(split):
    """Return the dataset for ``split`` together with a shuffled loader over it."""
    dataset = dutils.get_dataset(name=config.data.dataset, shape=config.data.shape,
                                 root=config.data.root, split=split,
                                 download=config.data.download)
    loader = data.DataLoader(dataset, batch_size=128, shuffle=True,
                             num_workers=config.data.num_workers)
    return dataset, loader


# Both splits share every setting except the split name
train_dataset, train_loader = _build_split("train")
valid_dataset, valid_loader = _build_split("valid")

Files already downloaded and verified
Files already downloaded and verified


### CelebA Facial Attributes Classification

In [14]:
# Names of the 40 binary attributes; the order matters because target indices are
# looked up in this list later (eclass.get_class_indices)
celeba_binary = ["5_o_clock_shadow", "arched_eyebrows", "attractive", "bags_under_eyes", "bald",
                 "bangs", "big_lips", "big_nose", "black_hair", "blond_hair", "blurry",
                 "brown_hair", "bushy_eyebrows", "chubby", "double_chin", "eyeglasses", "goatee",
                 "gray_hair", "heavy_makeup", "high_cheekbones", "male", "mouth_slightly_open",
                 "mustache", "narrow_eyes", "no_beard", "oval_face", "pale_skin", "pointy_nose",
                 "receding_hairline", "rosy_cheeks", "sideburns", "smiling", "straight_hair",
                 "wavy_hair", "wearing_earrings", "wearing_hat", "wearing_lipstick",
                 "wearing_necklace", "wearing_necktie", "young"]
# Subset of attributes the classifier below predicts (one output per entry)
celeba_targets = ["attractive", "eyeglasses", "male", "smiling", "young"]

In [15]:
# Per-attribute statistics over each split; with ratio=True the printed result maps
# every attribute name to a fraction (see the output of the next cell)
celeba_train_counts = eclass.get_class_counts(train_loader, celeba_binary, None, ratio=True)
celeba_valid_counts = eclass.get_class_counts(valid_loader, celeba_binary, None, ratio=True)

  0%|          | 0/1272 [00:00<?, ?it/s]

  0%|          | 0/156 [00:00<?, ?it/s]

In [16]:
# Inspect the training-split attribute ratios
print(celeba_train_counts)

{'5_o_clock_shadow': 0.11167291269890028, 'arched_eyebrows': 0.2658843767278983, 'attractive': 0.5136265896664004, 'bags_under_eyes': 0.20446028137863242, 'bald': 0.022811328868956197, 'bangs': 0.15165571051176507, 'big_lips': 0.24091048719051422, 'big_nose': 0.23555323462554525, 'black_hair': 0.23902439024390243, 'blond_hair': 0.149087669718007, 'blurry': 0.05137310315168643, 'brown_hair': 0.20391964121152548, 'bushy_eyebrows': 0.1436751244086748, 'chubby': 0.05768261964735517, 'double_chin': 0.04651348528598636, 'eyeglasses': 0.0646372181605947, 'goatee': 0.06350678872028015, 'gray_hair': 0.04236652945874547, 'heavy_makeup': 0.38431529151563554, 'high_cheekbones': 0.45244823984763777, 'male': 0.4193708914419119, 'mouth_slightly_open': 0.4821895926767832, 'mustache': 0.040806045340050376, 'narrow_eyes': 0.11592431037660503, 'no_beard': 0.8341770596547275, 'oval_face': 0.28322786754315904, 'pale_skin': 0.04303618602936659, 'pointy_nose': 0.2755176015236223, 'receding_hairline': 0.08011

In [25]:
# This is basically the original CIFAR-10 ResNet-20 architecture
resnet = ResNet(in_shape=config.data.shape, num_classes=len(celeba_targets),
                filters=[[16, 16], [32, 32], [64, 64]], kernels=[[3, 3], [3, 3], [3, 3]],
                repeats=[3, 3, 3], in_kernel=5, in_stride=2,
                in_max_pool_kernel=1, in_max_pool_stride=1)
resnet.to(device)
print(f"Number of parameters: {utils.get_size(resnet)}")

Number of parameters: 273701


In [None]:
# Train the attribute classifier; i_print evaluates to 1500, matching the spacing of the
# progress lines in the output below. The per-split attribute ratios are passed as weights.
eclass.classification(resnet, train_loader, valid_loader, celeba_binary, celeba_targets,
                      train_weights=celeba_train_counts, valid_weights=celeba_valid_counts,
                      i_max=30000, i_print=(30000 // 20), device=device)

  0%|          | 0/30000 [00:00<?, ?it/s]

[1500]     Training: 88.357% (0.2591)   Validation: 87.487% (0.2694)
[3000]     Training: 87.472% (0.2729)   Validation: 86.922% (0.2834)
[4500]     Training: 90.008% (0.2252)   Validation: 89.309% (0.2368)


In [None]:
# Inspect the classifier's parameters and buffers
resnet.state_dict()

In [None]:
# Sanity-check the classifier on one validation image: show the image, then print its
# ground-truth target attributes (row 0) above the predicted probabilities (row 1).
np.set_printoptions(suppress=True)

resnet.eval()
celeba_target_indices = eclass.get_class_indices(celeba_binary, celeba_targets)

# Index of the batch element to inspect. The image, its labels and its predictions must
# all use the same index; previously image 1 was shown next to the values for image 0.
sample_index = 0

with torch.no_grad():
    sample_test = next(iter(valid_loader))
    sample_outputs = torch.sigmoid(resnet(sample_test[0].to(device)))

display_torch_image(sample_test[0][sample_index])
print(torch.stack((sample_test[1][sample_index][celeba_target_indices],
                   sample_outputs[sample_index].cpu()), dim=0).numpy())

### Sample Many

This can be useful for training encoders for inversion later.

In [None]:
@torch.no_grad()
def sample_many(diffusion, num_t_steps, batch_size, num_batches, save_frequency, save_noise=True):
    """Generate many samples and write them to ``./samples/{run_path}`` in chunks.

    Every ``save_frequency`` batches the samples collected so far are concatenated and
    saved as ``samples_<count>.pth`` (and, if requested, the corresponding input noise
    as ``noises_<count>.pth``), where ``<count>`` is the zero-padded number of samples
    generated so far. Batches left over after the last full chunk are saved as well.

    Relies on the notebook globals ``run_path``, ``config`` and ``device``.

    Args:
        diffusion: Object whose ``sample(x=..., num_t_steps=..., sequence=False)``
            turns a noise batch into a sample batch.
        num_t_steps: Number of sampling steps passed to ``diffusion.sample``.
        batch_size: Number of samples generated per batch.
        num_batches: Total number of batches to generate.
        save_frequency: Number of batches per saved file.
        save_noise: Whether to also store the noise each sample was generated from.
    """
    out_dir = f"./samples/{run_path}"
    # makedirs also creates ./samples and tolerates re-runs; files with the same
    # names from an earlier run are overwritten
    os.makedirs(out_dir, exist_ok=True)

    index_width = len(str(num_batches * batch_size))
    samples, noises = [], []

    def flush(num_sampled):
        """Write the buffered tensors to disk and empty the buffers."""
        file_index = str(num_sampled).zfill(index_width)
        torch.save(torch.cat(samples, dim=0), f"{out_dir}/samples_{file_index}.pth")
        samples.clear()
        if save_noise:
            torch.save(torch.cat(noises, dim=0), f"{out_dir}/noises_{file_index}.pth")
            noises.clear()

    for i in tqdm(range(num_batches)):
        noise = torch.randn(batch_size, *config.data.shape)
        sample = diffusion.sample(x=noise.to(device), num_t_steps=num_t_steps, sequence=False)
        # Keep results on the CPU so long runs do not accumulate GPU memory
        samples.append(sample.detach().cpu())
        if save_noise:
            noises.append(noise.detach().cpu())

        if (i + 1) % save_frequency == 0:
            flush((i + 1) * batch_size)

    # num_batches need not be a multiple of save_frequency; keep the remainder
    if samples:
        flush(num_batches * batch_size)

In [None]:
# 64 x 4096 = 262144 samples in total, written every 64 batches (4096 samples per file)
sample_many(diffusion, num_t_steps=10, batch_size=64, num_batches=4096, save_frequency=64,
            save_noise=True)

### Sample Dataset

In [None]:
# Reads from the directory that sample_many writes to
sample_loader = dutils.get_loader_samples(batch_size=64, root=f"./samples/{run_path}",
                                          stop_iteration=True)  # Iterable custom loader

### FID

Computing the FID score (or even just the Inception Score) takes a very long time because generating the images is slow. For this reason, the default number of sampled images is currently only about $4096$, which is not a lot, especially since FID scores tend to be lower (i.e., better) when more sampled images are used. The standard is $50000$.

In [None]:
# FID evaluator built from the training and validation loaders
fid = FID(train_loader, valid_loader, config, device=device)

In [None]:
# Returns one score per reference split; 100 x 50 = 5000 is presumably the number of
# generated images used (TODO confirm against evaluation.fid)
fid_train, fid_valid = fid(diffusion, batch_size=100, num_batches=50)

In [None]:
# Report both scores with 7 significant digits
print(f"FID   |   Training: {fid_train.cpu().numpy():.7}   "
      f"Validation: {fid_valid.cpu().numpy():.7}")

## Inversion

In [None]:
def save_row(images, path, indices=None):
    """Lay out selected images side by side in one row, optionally save it, and show it.

    Args:
        images: Indexable collection of equally shaped image tensors.
        path: File to write the row image to, or ``None`` to skip saving.
        indices: Positions in ``images`` to include, in order. ``None`` selects
            every image.
    """
    chosen = range(len(images)) if indices is None else indices
    row = torch.stack([images[k] for k in chosen], dim=0)

    # nrow equal to the number of images keeps everything on a single row
    grid = vutils.make_grid(row, nrow=row.shape[0], padding=2, pad_value=1.0)

    if path is not None:
        save_torch_image(grid, path=path)
    display_torch_image(grid)

In [None]:
def save_torch_video(images, path, interval=50, scale=1, codec="h264"):
    """Save a sequence of channels-first image tensors as a video and show the figure.

    Args:
        images: Non-empty sequence of image tensors with the channel axis first.
            Tensors may live on any device and may require gradients.
        path: Output file passed to the animation's ``save`` method.
        interval: Delay between frames in milliseconds.
        scale: Multiplier applied to the figure size.
        codec: Video codec passed to the animation's ``save`` method.
    """
    # detach/cpu so that CUDA or grad-tracking tensors can be converted to numpy
    frames_hwc = [image.detach().cpu().moveaxis(0, -1).numpy() for image in images]

    # A single axes object that fills the whole figure, without ticks or borders
    figure = plt.figure()
    axes = plt.Axes(figure, [0.0, 0.0, 1.0, 1.0])
    axes.set_axis_off()
    figure.add_axes(axes)

    # set_size_inches takes (width, height) while the frames are (height, width, channels).
    # The divisor 100 presumably matches matplotlib's default figure dpi - TODO confirm.
    height, width = frames_hwc[0].shape[:2]
    figure.set_size_inches(width / 100 * scale, height / 100 * scale)

    artists = [[axes.imshow(frame, animated=True, aspect=1)] for frame in frames_hwc]

    # Honor the caller's interval instead of a hard-coded 50 ms
    animation_ = animation.ArtistAnimation(figure, artists, interval=interval)
    animation_.save(path, codec=codec)
    plt.show()

In [None]:
def load_image(path, resize=None):
    """Load an image file as a float32 RGB tensor with values in ``[0, 1]``.

    The file is converted to RGB on load, so grayscale and palette images work and
    any alpha channel is dropped.

    Args:
        path: Image file to open with PIL.
        resize: Optional target size passed to ``F.interpolate`` (area mode).

    Returns:
        Tensor with the three colour channels on the first axis.
    """
    # The context manager closes the file handle that Image.open keeps open lazily
    with Image.open(path) as file:
        pixels = np.asarray(file.convert("RGB"), dtype=np.float32) / 255

    # Channels last -> channels first, plus a batch axis because interpolate needs one
    image = torch.from_numpy(pixels).moveaxis(-1, 0).unsqueeze(dim=0)
    if resize is not None:
        image = F.interpolate(image, resize, mode="area")
    return image[0]

In [None]:
def find_from_dataset(path, loader):
    """Return the image from ``loader`` that is closest to the image stored at ``path``.

    Closeness is the mean squared error over all non-batch axes. The stored image
    is loaded without resizing, so it must broadcast against the loader's batches.

    Args:
        path: Image file to look up, read with ``load_image``.
        loader: Iterable yielding ``(images, label)`` batches.

    Returns:
        The best-matching image tensor, or ``None`` if the loader yields nothing.
    """
    target = load_image(path)
    best_image = None
    # Start from infinity so the first batch always wins, whatever the value range
    best_error = float("inf")
    for images, _ in tqdm(loader):
        errors = (target - images).square().mean(dim=(1, 2, 3))
        batch_best = errors.argmin(dim=0)
        if errors[batch_best] < best_error:
            best_error = errors[batch_best]
            best_image = images[batch_best]
    return best_image

### Optimization Inversion

In [None]:
# Selects the two test images used below:
#   list/tuple of two paths -> load those image files
#   any other truthy value  -> look up the stored validation images in the dataset
#   falsy value             -> take images from the (shuffled) validation loader
load_existing = False

In [None]:
if isinstance(load_existing, (list, tuple)):
    # Two explicit image files, resized to the 64 x 64 resolution set in config.data.shape
    test_image_1 = load_image(load_existing[0], resize=(64, 64))
    test_image_2 = load_image(load_existing[1], resize=(64, 64))
elif load_existing:
    # Recover the dataset entries matching previously saved validation images
    test_image_1 = find_from_dataset("results/validation/celeba_validation_1.png", valid_loader)
    test_image_2 = find_from_dataset("results/validation/celeba_validation_2.png", valid_loader)
else:
    # Take two different images from a single batch. This starts the loader workers
    # only once and cannot return the same image twice, even for an unshuffled loader.
    sample_batch = next(iter(valid_loader))[0]
    test_image_1 = sample_batch[0]
    test_image_2 = sample_batch[1]
display_torch_image(test_image_1, dpi=72)
display_torch_image(test_image_2, dpi=72)

In [None]:
# Random Gaussian starting latents, one per test image and with the same shape as it
z_1 = torch.randn(*test_image_1.shape, device=device)
z_2 = torch.randn(*test_image_2.shape, device=device)

In [None]:
# Settings shared by both inversions; only the target image differs
inversion_settings = dict(diffusion=diffusion, optimizer="adam", lr=0.02, num_i=300,
                          criterion="psnr", show_progress=True)

proj_fn_1 = partial(oinv.gradient_inversion, target=test_image_1, **inversion_settings)
proj_fn_2 = partial(oinv.gradient_inversion, target=test_image_2, **inversion_settings)

# Clones keep the initial latents z_1 and z_2 untouched; sequence=True is requested so
# that the results can be indexed with [-1] and animated in later cells
z_1_trained, x_1_reconstructed = proj_fn_1(z_1.clone(), sequence=True)
z_2_trained, x_2_reconstructed = proj_fn_2(z_2.clone(), sequence=True)

In [None]:
# Save and show each target next to the last entry of its reconstruction sequence
save_row((test_image_1, x_1_reconstructed[-1]), path="results/x_1_reconstructed.png")
save_row((test_image_2, x_2_reconstructed[-1]), path="results/x_2_reconstructed.png")

In [None]:
# Animate the whole reconstruction sequence of the first image
save_torch_video(x_1_reconstructed, path="results/x_1_reconstructed.webm", scale=1, codec="vp9")

In [None]:
# Animate the whole reconstruction sequence of the second image
save_torch_video(x_2_reconstructed, path="results/x_2_reconstructed.webm", scale=1, codec="vp9")

In [None]:
# Set to True to export a row of selected inversion steps for each test image
anime = False

if anime:
    indices = [0, 1, 5, 10, 15, 25, 50, 100, 200, 300]
    # The target image is prepended at position 0 below, so every step index shifts by one
    indices = [0] + [i + 1 for i in indices]

    save_row([test_image_1] + x_1_reconstructed, path="anime_inversion_1.png", indices=indices)
    save_row([test_image_2] + x_2_reconstructed, path="anime_inversion_2.png", indices=indices)

### Learning Inversion

In [None]:
# Network settings for the noise encoder; the keys largely mirror config.network,
# with fewer hidden channels and no attention layers
encoder_args = {"hidden_channels": 16, "num_blocks": 2, "channel_mults": [1, 2, 2, 4],
                "attention_sizes": [], "time_embed_channels": None, "dropout": 0.1,
                "num_groups": 8, "do_conv_sample": True, "out_conv_zero": False}
encoder = NoiseEncoder(config, network_args=encoder_args, loss_type="reconstruction",
                       diffusion=diffusion, device=device)
print(f"Number of parameters: {encoder.size()}")

In [None]:
# Train the encoder; loader points at the directory that sample_many writes to
diffusion_args = {"num_t_steps": 10}
optimizer_args = {"name": "adam", "learning_rate": 0.0002, "weight_decay": 0.0, "beta_1": 0.9,
                  "amsgrad": False, "epsilon": 1e-7}
encoder.train(diffusion_args, optimizer_args, batch_size=8, num_i=6000,
              z_criterion="l2", x_criterion="psnr", loader=f"./samples/{run_path}")

## Interpolation

In [None]:
# Interpolate between the final optimized latents of the two test images over 150 alphas;
# requires the "Optimization Inversion" cells to have been run first
x_mixes = iinv.proj_interpolation(z_1_trained[-1].to(device), z_2_trained[-1].to(device),
                                  diffusion=diffusion, proj_fn_1=None, proj_fn_2=None,
                                  num_t_steps=10, num_alphas=150, show_progress=True)

In [None]:
# Save and show the two final reconstructions side by side
save_row((x_1_reconstructed[-1], x_2_reconstructed[-1]), path="results/x_mixes.png")

In [None]:
# Seven roughly evenly spaced frames out of the 150 interpolation steps
save_row(x_mixes, indices=[0, 25, 49, 75, 99, 124, 149], path="results/x_mixes_.png")

In [None]:
# Animate the full interpolation sequence
save_torch_video(x_mixes, path="results/x_mixes.webm", scale=1, codec="vp9")

## Miscellaneous

In [None]:
# Scratch experiment: load the LSUN outdoor-church training set from the configured data root
from torch.utils import data
from torchvision import transforms, datasets

# Centre-crop every image to 256 x 256 before converting it to a tensor
transform = transforms.Compose([transforms.CenterCrop((256, 256)), transforms.ToTensor()])

lsun_data = datasets.LSUN(root=config.data.root, classes=["church_outdoor_train"], transform=transform)
lsun_loader = data.DataLoader(lsun_data, batch_size=config.training.batch_size,
                              shuffle=True, num_workers=config.data.num_workers)

In [None]:
# Take the first image of one batch, print its shape and display it
lsun_image = next(iter(lsun_loader))[0][0]
print(lsun_image.shape)
display_torch_image(lsun_image)

In [None]:
# Count the images by iterating over the whole loader once
i = 0
for image, label in tqdm(iter(lsun_loader)):
    i += image.shape[0]  # Size of this batch
print(i)

In [None]:
# Scratch experiment: the "miniplaces" training split, requested at shape (128, 128)
mini_data = dutils.get_dataset(name="miniplaces", shape=(128, 128),
                               root=config.data.root, split="train")
mini_loader = data.DataLoader(mini_data, batch_size=config.training.batch_size,
                              shuffle=True, num_workers=config.data.num_workers)

In [None]:
# Take the first image of one batch, print its shape and display it
mini_image = next(iter(mini_loader))[0][0]
print(mini_image.shape)
display_torch_image(mini_image)

In [None]:
# Count the images by iterating over the whole loader once
i = 0
for image, label in tqdm(iter(mini_loader)):
    i += image.shape[0]  # Size of this batch
print(i)