In [None]:
# Automatically reload externally edited modules (e.g. files changed in PyCharm) so the notebook
# picks up their updates without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Register the %tensorboard magic used further down to monitor training logs.
%load_ext tensorboard

In [None]:
from pathlib import Path

import yaml

import torch
from torch.utils import data
import matplotlib.pyplot as plt

In [None]:
from networks.unet import UNet
from runners.diffusion import Diffusion

import utilities.data as dutils
import utilities.ema as eutils
import utilities.math as mutils
import utilities.network as nutils
import utilities.runner as rutils
import utilities.utilities as utils

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
def display_torch_image(image, norm=(0, 1)):
    """Show a channels-first torch image tensor in a new matplotlib figure.

    Args:
        image: tensor whose channel axis is at dim -3, e.g. (C, H, W). A 4-D batched
            tensor (N, C, H, W) is accepted, and only its first image is shown.
        norm: (low, high) pair that is mapped linearly onto [0, 1] before display.
            If None, the image's own min and max are used.
    """
    # Leave the autograd graph and the GPU before doing any arithmetic for display.
    image = image.detach().cpu()
    if image.dim() == 4:
        image = image[0]
    if norm is None:
        norm = (image.min().item(), image.max().item())
    low, high = norm
    span = high - low
    if span == 0:
        # Constant image (or degenerate norm): avoid 0/0, which would render as NaN.
        span = 1
    # Clamp so matplotlib does not have to clip (and warn about) out-of-range RGB data.
    image = ((image - low) / span).clamp(0, 1)
    plt.figure()
    plt.axis("off")
    # imshow expects channels-last, so move the channel axis from -3 to the end.
    plt.imshow(image.moveaxis(-3, -1).numpy(), vmin=0, vmax=1)

In [None]:
# Experiment configuration (dataset, training and model settings) for CelebA.
CONFIG_PATH = "./configs/celeba.yml"
config = utils.get_yaml(path=CONFIG_PATH)

In [None]:
# Build the diffusion runner from the loaded config on the selected device.
diffusion = Diffusion(config, device=device)
n_parameters = diffusion.size()
print(f"Number of parameters: {n_parameters}")

In [None]:
# Launch TensorBoard inline on the logs/ directory (the same tree the checkpoint cell writes run_* into).
# --load_fast=false opts out of TensorBoard's fast log reader; --samples_per_plugin images=10000
# raises how many image samples TensorBoard keeps per tag (NOTE(review): verify against your TB version).
%tensorboard --logdir=logs --load_fast=false --samples_per_plugin images=10000

In [None]:
# Presumably runs the training loop of runners.diffusion.Diffusion (long-running) — TODO confirm.
diffusion.train()

In [None]:
# Persist the network weights under this run's log directory.
# NOTE(review): the step number was hard-coded in the filename; update it to the step actually reached.
CHECKPOINT_STEP = 63600
checkpoint_path = Path("logs") / f"run_{diffusion.datetime}" / f"model_{CHECKPOINT_STEP}.pth"
# torch.save does not create missing directories, so make sure the run folder exists.
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(diffusion.network.state_dict(), checkpoint_path)

In [None]:
# Training split and its shuffled loader, both driven by the config.
data_config = config.data
train_dataset = dutils.get_dataset(
    name=data_config.dataset,
    shape=data_config.shape,
    shape_original=data_config.shape_original,
    split="train",
    download=data_config.download,
)
train_loader = data.DataLoader(
    train_dataset,
    batch_size=config.training.batch_size,
    shuffle=True,
    num_workers=data_config.num_workers,
)

In [None]:
# Number of training samples. shuffle=True makes the DataLoader use a RandomSampler, which
# requires a map-style dataset with __len__, and drop_last defaults to False, so this equals
# the sum of all batch sizes without loading and decoding every image just to count them.
n = len(train_loader.dataset)
print(n)

In [None]:
# Show one real training image next to one freshly generated sample.
first_images, first_labels = next(iter(train_loader))
test_image = first_images[0]
display_torch_image(test_image)

test_generated = diffusion.sample(
    batch_size=config.training.batch_size,
    sequence=False,
)
display_torch_image(test_generated[0], norm=None)

In [None]:
# Sample a full batch with uniform step skipping, keeping the whole sampling sequence.
sequence_sampling_kwargs = dict(
    batch_size=config.training.batch_size,
    skip_type="uniform",
    sequence=True,
)
test_generated = diffusion.sample(**sequence_sampling_kwargs)

In [None]:
def min_max(tensor):
    """Print the maximum and minimum values of a tensor.

    Works for CPU or GPU tensors, including ones that require grad. An empty tensor
    raises RuntimeError from torch's max/min reductions.
    """
    # .numpy() refuses tensors that are part of an autograd graph, so detach first.
    values = tensor.detach()
    print(f"Max: {values.max().cpu().numpy()}   Min: {values.min().cpu().numpy()}")