In [None]:
import torch
from omegaconf import OmegaConf
from diffusion.model import NanoDiffusionModel
from diffusion.utils import CosineNoiseScheduler, DDIMSampler, decode_latents, get_available_device
from diffusers.models import AutoencoderKL
import matplotlib.pyplot as plt

In [None]:
# Compute device for the model, VAE and sampled tensors
# (chosen by the diffusion.utils helper; presumably the best available accelerator).
device = get_available_device()

# Relative paths: resolved against the kernel's working directory
# (normally the folder containing this notebook).
checkpoint_path = "../models/checkpoint_epoch_0999.pt"
config_path = "../config/config.yaml"

# True  -> take the config stored inside the checkpoint.
# False -> load and resolve the YAML file at `config_path`.
CONFIG_FROM_CHECKPOINT = False

In [None]:
# Load the training checkpoint; map_location moves every stored tensor onto `device`.
# NOTE(review): since PyTorch 2.6, torch.load defaults to weights_only=True. If the
# checkpoint pickles a non-tensor config object, loading may fail. Only pass
# weights_only=False for checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=device)

if CONFIG_FROM_CHECKPOINT:
    # The checkpoint is a dict (weights are read below via checkpoint["model_state_dict"]),
    # so the config must be looked up by key. The previous `checkpoint.model_conifg`
    # was misspelled and raised AttributeError.
    # NOTE(review): key name "model_config" assumed; confirm against the training script.
    # Later cells read cfg.model, cfg.noise_scheduler and cfg.num_timesteps, so the
    # stored object must be the full config, not just the model section.
    # OmegaConf.create gives attribute access even if it was saved as a plain dict.
    cfg = OmegaConf.create(checkpoint["model_config"])
else:
    cfg = OmegaConf.load(config_path)

# Resolve ${...} interpolations in place, whichever source the config came from.
OmegaConf.resolve(cfg)

In [None]:
# Denoising network: build from the config, move to `device`, switch to eval mode
# (disables dropout and makes norm layers use running statistics), then load weights.
model = NanoDiffusionModel(cfg.model).to(device).eval()
model.load_state_dict(checkpoint["model_state_dict"])

# Pretrained VAE; in this notebook it is only used to decode sampled latents to images.
vae = AutoencoderKL.from_pretrained(cfg.model.vae_name).to(device).eval()
noise_scheduler = CosineNoiseScheduler(cfg.noise_scheduler)

# Number of DDIM denoising steps used at sampling time. Presumably this is a subset of
# the cfg.num_timesteps schedule: fewer steps is faster, possibly lower quality.
NUM_SAMPLING_STEPS = 50
sampler = DDIMSampler(model, noise_scheduler, cfg.num_timesteps, NUM_SAMPLING_STEPS)

In [None]:
# Seed the global RNG so the sampled images are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Conditioning values 0..9, one per sample (presumably class labels; confirm with the model).
context = torch.arange(10).to(device)

# Initial Gaussian noise, one latent per context entry. It is drawn on the CPU and then
# moved, so a fixed seed gives the same noise whichever device is used.
# NOTE(review): (16, 4, 4) assumed to be the model's latent (channels, H, W); confirm
# against cfg.model / the VAE.
noise = torch.randn(len(context), 16, 4, 4).to(device)

# Inference only: avoid building an autograd graph during the denoising loop.
with torch.no_grad():
    latents = sampler.sample(noise, context)


In [None]:
# Decode the sampled latents to images with the VAE (inference only, no autograd).
with torch.no_grad():
    images = decode_latents(latents, vae)

# Materialise the decoded batch; each element is what imshow received before.
image_list = list(images)
n_images = len(image_list)

# Show every sample in one grid instead of one untitled figure per image.
MAX_COLS = 5
n_cols = max(1, min(MAX_COLS, n_images))
n_rows = max(1, -(-n_images // n_cols))  # ceiling division
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(2.5 * n_cols, 2.5 * n_rows), squeeze=False
)

labels = context.tolist()  # conditioning value used for each sample
for idx, ax in enumerate(axes.flat):
    ax.axis("off")  # also blanks any unused trailing cells
    if idx < n_images:
        ax.imshow(image_list[idx])
        if idx < len(labels):
            ax.set_title(f"context={labels[idx]}")

fig.suptitle("DDIM samples")
plt.show()