In [1]:
import pyrootutils
root = pyrootutils.setup_root(search_from=".", pythonpath=True)

import imageio
import torchvision
import hydra
import torch as T
from mltools.mltools.hydra_utils import reload_original_config
from mltools.mltools.diffusion import sample_heun, get_sigmas_karras

In [2]:
# Load a saved model
# NOTE(review): absolute path on the original author's machine; point this at your own
# saved network directory before running.
model_path = "/home/matthew/Documents/saved_networks/generation/network1"

# Recover the config saved alongside the network; its `ckpt_path` and `model._target_`
# entries are read below (ckpt_flag presumably selects which checkpoint file - confirm).
orig_config = reload_original_config(model_path, ckpt_flag="best*")

# Use the GPU when one is available, otherwise fall back to the CPU
dev = "cuda" if T.cuda.is_available() else "cpu"

# Resolve the model class named in the config and rebuild the network from the checkpoint
model_class = hydra.utils.get_class(orig_config.model._target_)
model = model_class.load_from_checkpoint(orig_config.ckpt_path, map_location=dev)

# Switch to evaluation mode before sampling: a freshly constructed module is in training
# mode, so dropout / batch-norm / any train-vs-eval branches would otherwise behave as
# they do during training. The trailing ';' hides the module repr in the notebook.
model.eval();

In [12]:
# Seed the global RNG so Restart & Run All regenerates the same sample;
# change SEED to draw a different one.
SEED = 0
T.manual_seed(SEED)

# Number of noise levels in the schedule (30 levels gave 29 sampler steps in the
# recorded run).
N_SIGMAS = 30

# Create random noise to generate from: a single sample with the model's input
# dimensions, scaled by the model's largest sigma
initial_noise = T.randn((1, *model.inpt_dim), device=model.device) * model.max_sigma

# Create the sigma function for the denoising
sigmas = get_sigmas_karras(model.min_sigma, model.max_sigma, N_SIGMAS)

# Run the sampler storing each output
# (keep_all=True presumably makes the sampler return every intermediate state in
# `stages`; disable=False presumably leaves the progress bar on - confirm in mltools)
outputs, stages = sample_heun(
    model=model,
    x=initial_noise,
    sigmas=sigmas,
    keep_all=True,
    disable=False,
)

# Undo the normalisation for each stage, drop the leading batch axis and clamp to [0, 1].
# Built as a new list rather than overwriting the sampler's entries one by one.
stages = [
    model.normaliser.reverse(stage).squeeze(0).clip(0, 1)
    for stage in stages
]

100%|██████████| 29/29 [00:01<00:00, 15.35it/s]


In [25]:
# Render the denoising trajectory as an animated gif
fn = torchvision.transforms.ToPILImage()

# Play the frames forwards and then backwards
ping_pong = stages + stages[::-1]
pil_images = list(map(fn, ping_pong))

# Write the animation to disk (for imageio's gif writer loop=0 means repeat indefinitely)
imageio.mimsave("generation.gif", pil_images, loop=0)