# Denoising diffusion: MNIST

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
sys.path.append('..')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from diffusion import (
    DDPM2d,
    UNet,
    make_beta_schedule
)

## MNIST data

In [None]:
# Location of the dataset on disk; the files are downloaded there if missing
# (download=True).
data_path = '../data'

# MNIST with train=False, each image converted by transforms.ToTensor().
val_set = datasets.MNIST(
    root=data_path,
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

print('No. val. images:', len(val_set))

In [None]:
# Number of images served per batch.
batch_size = 32

# Batches come in dataset order (shuffle=False) and the last, possibly
# smaller, batch is kept (drop_last=False).
val_loader = DataLoader(
    val_set,
    batch_size=batch_size,
    drop_last=False,
    shuffle=False,
    num_workers=4,
    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only machine just triggers a warning, so enable it only when CUDA
    # is actually present (this notebook falls back to the CPU further down).
    pin_memory=torch.cuda.is_available(),
)

print('No. val. batches:', len(val_loader))

In [None]:
# Take the first batch from the loader and report the tensor shapes.
first_batch = next(iter(val_loader))
x_batch, y_batch = first_batch
print('Images shape:', x_batch.shape)
print('Labels shape:', y_batch.shape)

In [None]:
# Show the first eight batch images on a 2x4 grid, each titled with the
# class name looked up from its label.
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(5, 3))
for idx, ax in enumerate(axes.ravel()):
    class_idx = y_batch[idx].item()
    # Channel 0 is the only one that is displayed.
    ax.imshow(x_batch[idx, 0].numpy(), cmap='gray')
    ax.set_title(val_set.classes[class_idx])
    # Strip ticks and axis labels so that only the image remains.
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()

## DDPM

In [None]:
# Path of the checkpoint to restore.
ckpt_file = '../mnist/version_0/checkpoints/last.ckpt'

# Map all stored tensors to the CPU: a checkpoint written from a GPU would
# otherwise fail to load on a machine without CUDA, and the model is used on
# the CPU first and only moved to an accelerator later in this notebook.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load(ckpt_file, map_location='cpu')

In [None]:
# Rebuild the model from the hyperparameters stored in the checkpoint and
# then restore its weights.
# Alternative that was left disabled:
# ddpm = DDPM2d.load_from_checkpoint(ckpt_file, **checkpoint['hyper_parameters'])
hparams = checkpoint['hyper_parameters']
ddpm = DDPM2d(**hparams)
ddpm.load_state_dict(checkpoint['state_dict'])

In [None]:
# Plot the model's beta schedule (left) and alpha-bar schedule (right)
# against the 1-based step number.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(8, 3))

# One entry per panel: axis, plotted values, upper y-limit, y-axis label.
panels = (
    (ax1, ddpm.betas, ddpm.betas.max(), '$\\beta$'),
    (ax2, ddpm.alphas_bar, 1, '$\\bar{\\alpha}$'),
)
for ax, schedule, y_upper, y_label in panels:
    steps = np.arange(len(schedule)) + 1
    ax.plot(steps, schedule)
    ax.set(xlim=(0, len(schedule)), ylim=(0, y_upper))
    ax.set(xlabel='t', ylabel=y_label)
    ax.grid(visible=True, which='both', color='gray', alpha=0.2, linestyle='-')
    # Keep the grid lines behind the plotted curve.
    ax.set_axisbelow(True)

fig.tight_layout()

## Forward process simulation

In [None]:
# Run the forward (noising) process on the sample batch.
# NOTE(review): the result is indexed below as [step, sample, channel], so it
# presumably holds one noisy copy of the batch per diffusion step — confirm
# against DDPM2d.diffuse_all_steps.
x_noisy = ddpm.diffuse_all_steps(x_batch)

In [None]:
# Step indices at which the noisy images are shown.
# NOTE(review): index 1000 presumably requires at least 1001 entries along
# the first axis of x_noisy — confirm if the number of steps is changed.
plot_steps = [0, 50, 100, 200, 500, 1000]

# Pick one random sample out of the batch (axis 1 of x_noisy).
sample_idx = np.random.randint(x_noisy.shape[1])

fig, axes = plt.subplots(nrows=1, ncols=len(plot_steps), figsize=(8, 2))
for ax, step in zip(axes.ravel(), plot_steps):
    ax.imshow(x_noisy[step, sample_idx, 0].numpy(), cmap='gray')
    ax.set_title('{} steps'.format(step))
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()

## Generation

In [None]:
# Use the first CUDA device when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Move the model to the selected device.
ddpm = ddpm.to(device)

In [None]:
# Switch the model to evaluation mode before sampling.
ddpm.eval()

# Start from 16 standard-normal noise images of shape (1, 28, 28), run the
# denoising pass on the selected device, and bring the result back to the CPU.
x_init = torch.randn(16, 1, 28, 28).to(device)
x_denoise = ddpm.denoise_all_steps(x_init).cpu()

In [None]:
# Mirror the forward-plot step indices: each entry is ddpm.num_steps minus a
# forward step, listed in ascending order.
plot_steps_reverse = [ddpm.num_steps - s for s in plot_steps[::-1]]

# Pick one random sample out of the batch (axis 1 of x_denoise).
sample_idx = np.random.randint(x_denoise.shape[1])

fig, axes = plt.subplots(nrows=1, ncols=len(plot_steps_reverse), figsize=(8, 2))
for ax, step in zip(axes.ravel(), plot_steps_reverse):
    ax.imshow(x_denoise[step, sample_idx, 0].numpy(), cmap='gray')
    ax.set_title('{} steps'.format(step))
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()

In [None]:
# Ask the model for 16 samples of shape (1, 28, 28) and move the result to
# the CPU for plotting.
# NOTE(review): presumably returns only the final images, indexed below as
# [sample, channel] — confirm against DDPM2d.generate.
x_gen = ddpm.generate(sample_shape=(1, 28, 28), num_samples=16).cpu()

In [None]:
# Show the first eight generated images on a 2x4 grid.
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(5, 3))
for idx, ax in enumerate(axes.ravel()):
    # Channel 0 is the only one that is displayed.
    ax.imshow(x_gen[idx, 0].numpy(), cmap='gray')
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()