# Denoising diffusion: MNIST

In [None]:
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported before each cell runs.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

import sys
# Put the parent directory on the import path — presumably so the local
# `diffusion` package imported below can be found; confirm the repo layout.
sys.path.append('..')

In [None]:
import pathlib

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import pytorch_lightning as pl

from diffusion import (
    DDPM,
    UNet,
    make_beta_schedule
)

## MNIST data

In [None]:
# Preprocessing for the training split: a random rotation (degrees argument 5)
# followed by conversion of the image to a tensor.
_rotate = transforms.RandomRotation(5)
_to_tensor = transforms.ToTensor()
transform = transforms.Compose([_rotate, _to_tensor])

In [None]:
# MNIST is stored in (and, if missing, downloaded to) ~/Data.
data_path = pathlib.Path.home() / 'Data'

# The training split uses the augmenting `transform`; the held-out split
# (train=False) is only converted to tensors.
train_set = datasets.MNIST(
    root=data_path,
    train=True,
    transform=transform,
    download=True,
)
val_set = datasets.MNIST(
    root=data_path,
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

# Report the size of each split.
for caption, split in (('No. train images:', train_set),
                       ('No. test images:', val_set)):
    print(caption, len(split))

In [None]:
batch_size = 32

# Options shared by the training and validation loaders.
loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)

# Training batches are shuffled and the last incomplete batch is dropped;
# validation keeps every sample in its original order.
train_loader = DataLoader(train_set, shuffle=True, drop_last=True, **loader_kwargs)
val_loader = DataLoader(val_set, shuffle=False, drop_last=False, **loader_kwargs)

print('No. train batches:', len(train_loader))
print('No. val. batches:', len(val_loader))

In [None]:
# Pull a single batch from the training loader to inspect and reuse below.
x_batch, y_batch = next(iter(train_loader))

for caption, tensor in (('Images shape:', x_batch),
                        ('Labels shape:', y_batch)):
    print(caption, tensor.shape)

In [None]:
# Show the first eight images of the batch in a 2x4 grid, titled by class name.
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(5, 3))
for i, panel in enumerate(axes.flat):
    panel.imshow(x_batch[i, 0].numpy(), cmap='gray')
    panel.set_title(train_set.classes[y_batch[i].item()])
    # Remove ticks and axis labels.
    panel.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()

## DDPM

In [None]:
# Hyperparameters of the U-Net that is passed to the DDPM as `eps_model`.
unet_config = dict(
    in_channels=1,
    mid_channels=[8, 16, 32],
    kernel_size=3,
    padding=1,
    norm='batch',
    activation='relu',
    embed_dim=100,
    num_resblocks=3,
    upsample_mode='bilinear_conv',
)

eps_model = UNet.from_params(**unet_config)

In [None]:
# Alternative schedules (swap one in for the sigmoid schedule to try it):
# betas = make_beta_schedule(num_steps=1000, mode='quadratic', beta_range=(1e-04, 0.02))
# betas = make_beta_schedule(num_steps=1000, mode='cosine', cosine_s=0.008)
schedule_config = dict(num_steps=1000, mode='sigmoid', sigmoid_range=(-5, 5))
betas = make_beta_schedule(**schedule_config)

# Combine the noise-prediction network and the beta schedule into the DDPM,
# using the 'mse' criterion.
ddpm = DDPM(eps_model=eps_model, betas=betas, criterion='mse')

In [None]:
# Plot the beta schedule (left) and alphas_bar (right) against the step index t.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(8, 3))

# One entry per panel: (axis, values to plot, y-axis label, y-axis limits).
panels = (
    (ax1, ddpm.betas, '$\\beta$', (0, ddpm.betas.max())),
    (ax2, ddpm.alphas_bar, '$\\bar{\\alpha}$', (0, 1)),
)
for axis, series, y_label, y_limits in panels:
    n_steps = len(series)
    # Steps are plotted 1-based.
    axis.plot(np.arange(n_steps) + 1, series)
    axis.set(xlim=(0, n_steps), ylim=y_limits)
    axis.set(xlabel='t', ylabel=y_label)
    axis.grid(visible=True, which='both', color='gray', alpha=0.2, linestyle='-')
    axis.set_axisbelow(True)

fig.tight_layout()

## Forward process simulation

In [None]:
# Run the forward (noising) process on the sample batch. The result is indexed
# below as [time step, sample, channel], so it presumably holds one noisy copy
# of the batch per diffusion step — confirm against DDPM.diffuse_all_steps.
x_noisy = ddpm.diffuse_all_steps(x_batch)

In [None]:
# Time indices of the forward process to visualise.
plot_steps = [0, 50, 100, 200, 500, 1000]

# Follow a single, randomly chosen sample of the batch (second axis of x_noisy).
batch_len = x_noisy.shape[1]
sample_idx = np.random.randint(batch_len)

fig, axes = plt.subplots(nrows=1, ncols=len(plot_steps), figsize=(8, 2))
for panel, step in zip(axes.flat, plot_steps):
    panel.imshow(x_noisy[step, sample_idx, 0].numpy(), cmap='gray')
    panel.set_title('{} steps'.format(step))
    # Remove ticks and axis labels.
    panel.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()

## Reverse process training

In [None]:
# Metrics are written as CSV under ./lightning_logs/version_0.
logger = pl.loggers.CSVLogger('.', name='lightning_logs', version=0)

# Train on a GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    accelerator = 'gpu'
else:
    accelerator = 'cpu'

# Log once per pass over the training loader.
steps_per_epoch = len(train_loader)

trainer = pl.Trainer(
    logger=logger,
    accelerator=accelerator,
    devices=1,
    max_epochs=1000,
    log_every_n_steps=steps_per_epoch,
    enable_progress_bar=True,
)

# Record the validation loss of the untrained model first, then train.
trainer.validate(ddpm, dataloaders=val_loader, verbose=False)
trainer.fit(ddpm, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Load the metrics CSV written during training.
metrics_df = pd.read_csv('lightning_logs/version_0/metrics.csv')

# Build one frame per loss, dropping the rows where that loss was not logged.
train_df, val_df = (
    metrics_df[['step', loss_column]].dropna()
    for loss_column in ('train_loss', 'val_loss')
)

In [None]:
# Training and validation loss curves over the logged steps.
fig, ax = plt.subplots(figsize=(6, 4))

curves = (
    (train_df, 'train_loss', 'train'),
    (val_df, 'val_loss', 'val'),
)
for frame, loss_column, legend_name in curves:
    ax.plot(frame['step'], frame[loss_column], alpha=0.7, label=legend_name)

ax.set(xlabel='step', ylabel='loss')

# The x-axis runs from 0 to the last step present in either curve.
last_step = max(train_df['step'].max(), val_df['step'].max())
ax.set_xlim([0, last_step])

ax.legend()
ax.grid(visible=True, which='both', color='gray', alpha=0.2, linestyle='-')
ax.set_axisbelow(True)
fig.tight_layout()

## Generation

In [None]:
# Use the first CUDA device when available, otherwise the CPU, and move the
# model there for generation.
device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
ddpm = ddpm.to(device)

In [None]:
# Switch the model to evaluation mode before sampling.
ddpm.eval()

# Start from 16 random-normal images of shape (1, 28, 28), run the reverse
# process, and bring the result back to the CPU for plotting.
x_start = torch.randn(16, 1, 28, 28).to(device)
x_denoise = ddpm.denoise_all_steps(x_start).cpu()

In [None]:
# Mirror the forward-process time indices so they count from the start of
# the reverse process.
plot_steps_reverse = [ddpm.num_steps - s for s in plot_steps[::-1]]

# Follow a single, randomly chosen sample (second axis of x_denoise).
batch_len = x_denoise.shape[1]
sample_idx = np.random.randint(batch_len)

fig, axes = plt.subplots(nrows=1, ncols=len(plot_steps_reverse), figsize=(8, 2))
for panel, step in zip(axes.flat, plot_steps_reverse):
    panel.imshow(x_denoise[step, sample_idx, 0].numpy(), cmap='gray')
    panel.set_title('{} steps'.format(step))
    # Remove ticks and axis labels.
    panel.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()

In [None]:
# Generate 16 new samples of shape (1, 28, 28) with the model and move them to
# the CPU for plotting.
x_gen = ddpm.generate(sample_shape=(1, 28, 28), num_samples=16).cpu()

In [None]:
# Show the first eight generated images in a 2x4 grid.
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(5, 3))
for i, panel in enumerate(axes.flat):
    panel.imshow(x_gen[i, 0].numpy(), cmap='gray')
    # Remove ticks and axis labels.
    panel.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()