# Conditional DDPM (MNIST)

In [None]:
# Enable the autoreload extension; mode 2 reloads all imported modules before
# each cell is executed, so edits to the project source are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

import sys
# Make the parent directory importable (presumably where the local
# `diffusion` package imported below lives - confirm against the repo layout).
sys.path.append('..')

In [None]:
import matplotlib.pyplot as plt
import torch
from lightning.pytorch import seed_everything

from diffusion import MNISTDataModule, DDPM2d

In [None]:
# Seed the random number generators with a fixed value so that the sampling
# below is repeatable. The return value is bound to `_`, presumably to keep
# it from being echoed as the cell output.
_ = seed_everything(111111)  # set random seeds manually

## MNIST data

In [None]:
# MNIST data module reading from / writing to ../run/data/.
# NOTE(review): mean=None and std=None presumably disable normalization, so
# pixel values stay in [0, 1] - confirm against MNISTDataModule.
mnist = MNISTDataModule(
    batch_size=32,
    random_state=42,
    data_dir='../run/data/',
    mean=None,
    std=None
)

# Download the data if that has not been done yet.
mnist.prepare_data()

# Create the test set.
mnist.setup(stage='test')

In [None]:
# Grab the first mini-batch of the test loader for inspection.
test_loader = mnist.test_dataloader()
x_batch, y_batch = next(iter(test_loader))

# Shape of a single sample: everything after the leading (batch) dimension.
image_shape = x_batch.size()[1:]

In [None]:
# Preview the first 25 images of the batch on a 5x5 grid.
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(5, 5))
for sample_idx, panel in enumerate(axes.flat):
    # Draw the first channel only, with the gray scale pinned to [0, 1].
    pixels = x_batch[sample_idx, 0].numpy()
    panel.imshow(pixels, cmap='gray', vmin=0, vmax=1)

    # Strip ticks and axis labels so that only the image remains.
    panel.set_xticks([])
    panel.set_yticks([])
    panel.set_xlabel('')
    panel.set_ylabel('')
fig.tight_layout()

## DDPM import

In [None]:
# Checkpoint of the trained conditional model.
ckpt_file = '../run/mnist_cond/version_0/checkpoints/last.ckpt'

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Map the stored tensors directly onto `device`. Without `map_location` the
# weights are restored onto the device they were saved from, which can fail
# when that device (e.g. a higher-numbered GPU) does not exist on this machine.
# NOTE(review): assumes DDPM2d inherits the standard Lightning
# `load_from_checkpoint` signature - confirm.
ddpm = DDPM2d.load_from_checkpoint(ckpt_file, map_location=device)

# Inference only: switch to eval mode and make sure the module is on `device`.
ddpm = ddpm.eval()
ddpm = ddpm.to(device)

## Generative process

In [None]:
# Class labels to condition on; one image is generated per entry.
cids = torch.tensor([1, 2, 3, 4, 5])  # set targets

# Starting point of the reverse process: Gaussian noise, one sample per label.
x_noise = torch.randn(len(cids), *image_shape)

# Pass the noise AND the labels on the model's device. The original moved only
# the noise, leaving the labels on the CPU, which risks a device mismatch when
# the model runs on a GPU. `cids` itself stays on the CPU for later cells.
# The result is moved back to the CPU for plotting.
x_denoise = ddpm.denoise_all_steps(
    x_noise.to(device),
    cids=cids.to(device)
).cpu()

In [None]:
# Offsets counted back from the end of the trajectory.
# NOTE(review): the largest entry assumes ddpm.num_steps >= 1000; with fewer
# steps the subtraction below yields negative (wrap-around) indices.
plot_steps = [0, 20, 50, 100, 200, 500, 1000]

# Convert the offsets into ascending indices along the first axis of
# x_denoise (presumably index 0 is the initial noise and index num_steps the
# final sample - confirm against denoise_all_steps).
reverse_plot_steps = [ddpm.num_steps - offset for offset in plot_steps[::-1]]

# One row per conditioned sample, one column per selected step.
fig, axes = plt.subplots(nrows=len(cids), ncols=len(reverse_plot_steps), figsize=(9, 8))
for sample_idx, axes_row in enumerate(axes):
    for step_idx, panel in zip(reverse_plot_steps, axes_row):
        # First channel of this sample at the selected step.
        pixels = x_denoise[step_idx, sample_idx, 0].numpy()
        panel.imshow(pixels, cmap='gray')
        panel.set_title('{} steps'.format(step_idx))

        # Strip ticks and axis labels so that only the image remains.
        panel.set_xticks([])
        panel.set_yticks([])
        panel.set_xlabel('')
        panel.set_ylabel('')
fig.tight_layout()

In [None]:
# Number of images to generate per digit class.
num_repeats = 5

# Labels 0..9, each repeated num_repeats times back to back (0, 0, ..., 1, 1, ...).
cids = torch.repeat_interleave(torch.arange(10), num_repeats)  # set targets

# Generate one image per label, then move the result to the CPU for plotting.
x_gen = ddpm.generate(sample_shape=image_shape, cids=cids, num_samples=len(cids))
x_gen = x_gen.cpu()

In [None]:
# Grid with one column per digit class and one row per repetition.
fig, axes = plt.subplots(nrows=num_repeats, ncols=10, figsize=(10, num_repeats))

# Column-major traversal fills each column top to bottom before moving on, so
# consecutive samples (generated for the same label, assuming the output
# follows the order of cids) end up stacked in one column.
for sample_idx, panel in enumerate(axes.ravel(order='F')):
    pixels = x_gen[sample_idx, 0].numpy()
    panel.imshow(pixels, cmap='gray', vmin=0, vmax=1)

    # Strip ticks and axis labels so that only the image remains.
    panel.set_xticks([])
    panel.set_yticks([])
    panel.set_xlabel('')
    panel.set_ylabel('')
fig.tight_layout()