# Model inference
This notebook allows you to load model checkpoints and use them to generate images.

In [None]:
import torch
from diffusion.model import NanoDiffusionModel
from diffusion.utils import CosineNoiseScheduler, DDIMSampler, decode_latents, get_available_device
from diffusers.models import AutoencoderKL
import matplotlib.pyplot as plt
from ipyfilechooser import FileChooser

In [None]:
# Device that the models and tensors in this notebook are moved to.
DEVICE = get_available_device()

# CIFAR-10 class names; a class id is used as the index into this list.
LABEL_NAMES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

In [None]:
# Widget for picking a checkpoint file; browsing starts in ../models/.
fc = FileChooser(path="../models/", filename="", title="Select a model checkpoint:")

# Render the widget so a file can be chosen before the following cells run.
display(fc)

In [None]:
# `fc.selected` stays None until a file has been confirmed in the chooser widget.
checkpoint_path = fc.selected
if not checkpoint_path:
    raise ValueError(
        "No checkpoint selected: pick a file in the chooser above, then re-run this cell."
    )

# map_location remaps the stored tensors onto DEVICE, so a checkpoint that was
# saved on a different device (e.g. a GPU) still loads on this machine.
# NOTE(security): weights_only=False uses the full unpickler, which can execute
# arbitrary code -- only open checkpoints from a trusted source. It is presumably
# required because the checkpoint stores config objects and not just tensors.
checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)

model_config = checkpoint["model_config"]
noise_scheduler_config = checkpoint["noise_scheduler_config"]

# Build the network from the saved config, then restore its saved weights.
model = NanoDiffusionModel(model_config).to(DEVICE).eval()
model.load_state_dict(checkpoint["model_state_dict"])

# The VAE named in the model config; used further down to decode latents into images.
vae = AutoencoderKL.from_pretrained(model_config.vae_name).to(DEVICE).eval()
noise_scheduler = CosineNoiseScheduler(noise_scheduler_config)

# NOTE(review): the last argument (50) looks like the number of sampling steps --
# confirm against DDIMSampler's signature.
sampler = DDIMSampler(model, noise_scheduler, noise_scheduler_config.num_timesteps, 50)

## Sample all classes
Generate an image for each class of CIFAR-10 and display them.

In [None]:
# One sample per class, in class-id order, so the batch size follows LABEL_NAMES.
num_classes = len(LABEL_NAMES)
# Latent shape (everything after the batch dimension) handed to the sampler.
# NOTE(review): assumed to match the latents the model was trained on -- confirm.
latent_shape = (16, 4, 4)

# Pure inference: no autograd graph is needed (harmless if the helpers already disable it).
with torch.no_grad():
    noise = torch.randn(num_classes, *latent_shape, device=DEVICE)
    labels = torch.arange(num_classes, device=DEVICE)
    latents = sampler.sample(noise, labels)
    images = decode_latents(latents, vae)

# Five images per row; the number of rows follows the number of classes.
n_cols = 5
n_rows = -(-num_classes // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3 * n_rows))
axes = axes.flatten()

# labels are 0..num_classes-1, so image i belongs to LABEL_NAMES[i]; pairing with
# the names directly avoids indexing a Python list with a device tensor.
for ax, img, name in zip(axes, images, LABEL_NAMES):
    ax.imshow(img)
    ax.set_title(name, fontsize=14)
    ax.axis('off')

# Blank out any grid cells left over when num_classes is not a multiple of n_cols.
for ax in axes[num_classes:]:
    ax.axis('off')

plt.suptitle("Generated CIFAR-10 Samples", fontsize=16)
plt.tight_layout()
plt.show()

## Intermediate generation steps
This time, we also decode intermediate steps of the sampling process (every fifth step, plus the final result) into image space. This lets us see how the model denoises each image over the course of sampling.

In [None]:
# One trajectory per class, in class-id order.
class_ids = list(range(len(LABEL_NAMES)))
# Keep every 5th state of the sampling trajectory for plotting.
snapshot_stride = 5

# Pure inference: no autograd graph is needed (harmless if the helpers already disable it).
with torch.no_grad():
    noise = torch.randn(len(class_ids), 16, 4, 4, device=DEVICE)
    labels = torch.tensor(class_ids, device=DEVICE)
    final_latents, intermediates = sampler.sample(noise, labels, return_intermediates=True)

    # Full trajectory: the starting noise followed by the sampler's intermediate states.
    trajectory = [noise, *intermediates]
    # Sub-sample the trajectory and always finish with the final result.
    snapshots = trajectory[::snapshot_stride] + [final_latents]
    latents = torch.stack(snapshots)  # (snapshot, class, *latent dims)

    # Take the grid size from the data instead of hard-coding it, so the image
    # indexing below stays aligned whatever number of states the sampler returns.
    num_snapshots, num_classes = latents.shape[:2]

    # Reorder to class-major so that each class's snapshots are contiguous (one plot row).
    latents = latents.transpose(0, 1).reshape(-1, *latents.shape[2:])

    images = decode_latents(latents, vae)

# Approximate column labels, running from the scheduler's first timestep down to 0.
# NOTE(review): assumes the kept snapshots are evenly spaced over the timesteps -- confirm
# against the sampler's timestep schedule.
total_timesteps = noise_scheduler_config.num_timesteps
timesteps = [
    round(total_timesteps * (1 - col / (num_snapshots - 1)))
    for col in range(num_snapshots)
]

fig, axes = plt.subplots(num_classes, num_snapshots, figsize=(12, 10))

for row, class_id in enumerate(class_ids):
    for col in range(num_snapshots):
        ax = axes[row, col]
        ax.imshow(images[row * num_snapshots + col])

        if row == 0:
            ax.set_title(f"t={timesteps[col]}", fontsize=8, pad=2)

        if col == 0:
            # The first column keeps its axis switched on so the class name can be
            # shown as a y-label; ticks and spines are hidden instead, otherwise
            # they mess up the layout.
            ax.set_ylabel(LABEL_NAMES[class_id], fontsize=8, rotation=0, labelpad=35)
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
        else:
            ax.axis('off')

plt.suptitle("Denoising trajectories", fontsize=16)
plt.tight_layout()
plt.show()