In [1]:
# Automatically reload externally edited modules (e.g., files changed in PyCharm)
# in this notebook whenever they are updated.
%load_ext autoreload
%autoreload 2

%load_ext tensorboard

In [2]:
import yaml

import torch
from torch import optim
from torch.utils import data
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

In [3]:
from networks.unet import UNet
from runners.diffusion import Diffusion

import utilities.data as dutils
import utilities.math as mutils
import utilities.network as nutils
import utilities.runner as rutils
import utilities.utilities as utils

In [4]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
def display_torch_image(image, norm=(0, 1)):
    """Show a torch image tensor in a new matplotlib figure with the axes hidden.

    Args:
        image: Image tensor with channels on axis -3 (presumably ``(C, H, W)``;
            the channel axis is moved last for ``imshow``). A 4-D batch is also
            accepted, in which case only its first element is shown.
        norm: ``(low, high)`` pair mapped linearly onto ``[0, 1]`` before display.
            ``None`` uses the image's own minimum and maximum.
    """
    if image.dim() == 4:
        # Batched input: display only the first image of the batch.
        image = image[0]
    if norm is None:
        norm = (image.min(), image.max())
    low, high = norm
    span = high - low
    # A constant image with norm=None gives span == 0; dividing would fill the
    # image with NaNs, so show it as all zeros instead.
    if span == 0:
        image = image - low
    else:
        image = (image - low) / span
    # imshow ignores vmin/vmax for RGB(A) data and clips out-of-range floats with
    # a warning, so clamp to the displayable range explicitly.
    image = image.clamp(0, 1)
    plt.figure()
    plt.axis("off")
    plt.imshow(image.moveaxis(-3, -1).detach().cpu().numpy(), vmin=0, vmax=1)

In [6]:
# Experiment configuration loaded from YAML; later cells read `config.data.*`
# and `config.training.*` from it.
config = utils.get_yaml(path="./configs/celeba.yml")

In [7]:
# Build the diffusion runner from the config on the selected device and report
# how many parameters it has.
diffusion = Diffusion(config, device=device)
print(f"Number of parameters: {diffusion.size()}")

Number of parameters: 5068323


In [8]:
# Launch TensorBoard inline on ./logs; `--samples_per_plugin images=10000` raises
# how many image samples TensorBoard keeps per tag so training samples are not dropped.
%tensorboard --logdir=logs --load_fast=false --samples_per_plugin images=10000

In [None]:
# Train the model. Presumably writes its logs under logs/run_<diffusion.datetime>/
# (the directory the checkpoint cell below saves into) — confirm in runners.diffusion.
diffusion.train()

  0%|          | 0/15 [00:00<?, ?it/s]

  0%|          | 0/2544 [00:00<?, ?it/s]

  0%|          | 0/2544 [00:00<?, ?it/s]

  0%|          | 0/2544 [00:00<?, ?it/s]

  0%|          | 0/2544 [00:00<?, ?it/s]

  0%|          | 0/2544 [00:00<?, ?it/s]

  0%|          | 0/2544 [00:00<?, ?it/s]

  0%|          | 0/2544 [00:00<?, ?it/s]

  0%|          | 0/2544 [00:00<?, ?it/s]

  0%|          | 0/2544 [00:00<?, ?it/s]

  0%|          | 0/2544 [00:00<?, ?it/s]

  0%|          | 0/2544 [00:00<?, ?it/s]

In [None]:
# Save the network weights into this run's log directory.
# NOTE(review): the step count in the filename (63600) is hard-coded — confirm it
# matches the number of training steps actually performed before reusing this cell.
torch.save(diffusion.network.state_dict(), f"logs/run_{diffusion.datetime}/model_63600.pth")

In [None]:
# Validation split, used below to pick a target image for the inversion experiment.
valid_dataset = dutils.get_dataset(name=config.data.dataset, shape=config.data.shape,
                                   shape_original=config.data.shape_original, split="valid",
                                   download=config.data.download)
# The loader must wrap `valid_dataset`: no `train_dataset` is defined in this notebook.
# shuffle=True means a different validation image is picked on each run (no seed is set).
valid_loader = data.DataLoader(valid_dataset, batch_size=config.training.batch_size,
                               shuffle=True, num_workers=config.data.num_workers)

In [None]:
# Take the first image of one (shuffled) validation batch as the inversion target.
# Assumes each batch is indexable with the images at position 0 — confirm against dutils.
test_image = next(iter(valid_loader))[0][0]
display_torch_image(test_image)

In [None]:
def naive_inversion(diffusion, target, z=None, num_i=1000, lr=0.001):
    """Optimize a latent so that sampling the diffusion model from it reproduces ``target``.

    Displays the sample generated from the initial latent. It then runs ``num_i`` Adam
    steps on the summed squared error between ``target`` and the sample generated
    from ``z``, backpropagating through ``diffusion.sample``. Finally it displays the
    sample generated from the optimized latent.

    Args:
        diffusion: Runner exposing ``sample(x=..., sequence=False)``; element ``[0]``
            of its result is used as the generated image.
        target: Image tensor to reconstruct; moved to the notebook-level ``device``.
        z: Optional initial latent. Defaults to a standard-normal tensor shaped like
            ``target`` on ``device``. It is marked ``requires_grad`` and updated in place.
        num_i: Number of optimization steps.
        lr: Adam learning rate (defaults to the previously hard-coded 0.001).
    """
    # `target` typically comes straight from a CPU DataLoader while samples are
    # produced on `device`; align them so the loss can be computed.
    # NOTE(review): target and z are unbatched here — confirm diffusion.sample
    # accepts an unbatched latent and that its [0] output matches target's shape.
    target = target.to(device)
    if z is None:
        z = torch.randn(*target.shape, device=device)
    z.requires_grad_()

    with torch.no_grad():
        y_0 = diffusion.sample(x=z.detach(), sequence=False)[0].detach()
    display_torch_image(y_0)

    optimizer = optim.Adam([z], lr=lr, betas=(0.9, 0.999), eps=1e-8)

    for _ in tqdm(range(num_i)):
        # Reset gradients each step; otherwise they accumulate across iterations.
        optimizer.zero_grad()
        y = diffusion.sample(x=z, sequence=False)[0]
        loss = (target - y).square().sum()
        loss.backward()
        # Apply the update — without this step `z` would never change.
        optimizer.step()

    with torch.no_grad():
        y_t = diffusion.sample(x=z.detach(), sequence=False)[0].detach()
    display_torch_image(y_t)

In [None]:
# Try to recover a latent whose diffusion sample reproduces `test_image`,
# displaying the sample before and after optimization.
naive_inversion(diffusion, test_image)

In [None]:
def min_max(tensor):
    """Print the maximum and minimum of ``tensor`` on one line.

    The tensor is detached first so this also works for tensors that require
    grad (e.g. an optimized latent); ``.numpy()`` raises on those otherwise.
    ``.cpu()`` lets it handle tensors on any device.
    """
    values = tensor.detach()
    print(f"Max: {values.max().cpu().numpy()}   Min: {values.min().cpu().numpy()}")