# Model inference
This notebook allows you to load model checkpoints and use them to generate images.

In [None]:
import torch
from omegaconf import OmegaConf
from diffusion.model import NanoDiffusionModel
from diffusion.utils import CosineNoiseScheduler, DDIMSampler, decode_latents, get_available_device
from diffusers.models import AutoencoderKL
import matplotlib.pyplot as plt
from ipyfilechooser import FileChooser
from pathlib import Path

In [None]:
# Fix the global RNG so the sampled noise (and therefore the generated images) is
# reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

DEVICE = get_available_device()
# Index i is the human-readable name of class id i (used for plot titles below).
LABEL_NAMES = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

In [None]:
# Interactive picker for the checkpoint file; browsing starts in the models directory.
# The next cell reads the chosen path from `fc.selected`.
fc = FileChooser(
    title='Select a model checkpoint:',
    path='../models/',
    filename='',
)
display(fc)

In [None]:
checkpoint_path = fc.selected
if checkpoint_path is None:
    # Running all cells without picking a file would otherwise fail inside torch.load
    # with an unhelpful error.
    raise ValueError("No checkpoint selected: choose a file in the file chooser above, then re-run this cell.")

# map_location remaps tensors saved on another device (e.g. a CUDA training machine) onto
# DEVICE, so the checkpoint also loads on machines without that device.
# weights_only=False is needed because the checkpoint also stores config objects (read via
# attribute access below). It unpickles arbitrary objects: only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)

model_config = checkpoint["model_config"]
noise_scheduler_config = checkpoint["noise_scheduler_config"]
model = NanoDiffusionModel(model_config).to(DEVICE).eval()
model.load_state_dict(checkpoint["model_state_dict"])

vae = AutoencoderKL.from_pretrained(model_config.vae_name).to(DEVICE).eval()
noise_scheduler = CosineNoiseScheduler(noise_scheduler_config)

# NOTE(review): 50 is presumably the number of DDIM sampling steps (as opposed to the
# scheduler's num_timesteps) — confirm against DDIMSampler's signature.
sampler = DDIMSampler(model, noise_scheduler, noise_scheduler_config.num_timesteps, 50)

## Sample all classes
Generate one image for each of the ten CIFAR-10 classes and display them in a grid.

In [None]:
# Sample one latent per CIFAR-10 class in a single batch, then decode the batch with the VAE.
noise = torch.randn(10, 16, 4, 4).to(DEVICE)
classes = torch.arange(10).to(DEVICE)  # class ids 0..9, aligned with LABEL_NAMES
latents = sampler.sample(noise, classes)
images = decode_latents(latents, vae)

# 2 x 5 grid: one panel per class, titled with the class name.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.flatten()
for ax, img, label_id in zip(axes, images, classes):
    ax.imshow(img)
    ax.set_title(LABEL_NAMES[label_id], fontsize=14)
    ax.axis('off')

plt.suptitle("Generated CIFAR-10 Samples", fontsize=16, y=1.02)
plt.tight_layout()
plt.show()

## Intermediate generation steps
This time, we decode the intermediate latents produced during sampling into image space. This shows how the model progressively denoises the image over the timesteps.

In [None]:
noise = torch.randn(1, 16, 4, 4).to(DEVICE)
classes = torch.tensor([0]).reshape(-1).to(DEVICE)
latent, intermediates = sampler.sample(noise, classes, return_intermediates=True)

# Decode the intermediates of *this* run. The cell previously decoded the stale `latents`
# from the cell above, so no intermediate step was ever shown.
# NOTE(review): assumes `intermediates` is an iterable of per-step latents for the single
# sample, each reshapeable to noise.shape[1:] — confirm against DDIMSampler.
latent_shape = noise.shape[1:]
trajectory = torch.stack(
    [step.to(DEVICE).reshape(latent_shape) for step in intermediates]
    + [latent.to(DEVICE).reshape(latent_shape)]  # final sample as the last frame
)

# Show an evenly spaced subset of the trajectory, always including its first and last frame.
num_panels = min(10, len(trajectory))
step_ids = torch.linspace(0, len(trajectory) - 1, num_panels).round().long().tolist()
images = decode_latents(trajectory[step_ids], vae)

fig, axes = plt.subplots(2, 5, figsize=(15, 6))
axes = axes.flatten()

for ax, img, step_id in zip(axes, images, step_ids):
    ax.imshow(img)
    ax.set_title(f"Step {step_id}", fontsize=14)
    ax.axis('off')
for ax in axes[len(step_ids):]:
    ax.axis('off')  # hide unused panels when the trajectory is shorter than the grid

plt.suptitle(f"Denoising trajectory for class '{LABEL_NAMES[classes[0]]}'", fontsize=16, y=1.02)
plt.tight_layout()
plt.show()