<font face='monospace'>

## <b>Denoising Diffusion Probabilistic Models - DDPM</b>

Importing the required modules

In [None]:
%pip install -qU fastai fastcore accelerate datasets torcheval diffusers ffmpeg-python

In [None]:
import os
import torch
import logging
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torchvision.transforms.functional as TF

from pathlib import Path
from torch.nn import init
from torch import nn,optim
from fastcore.all import *
from functools import partial
from diffusers import UNet2DModel
from fastcore.foundation import L
from datasets import load_dataset
from accelerate import Accelerator
from torch.optim import lr_scheduler
from IPython.display import display, HTML
from torch.utils.data import DataLoader,default_collate

from diffusion_ai import *

In [None]:
# Notebook-wide display settings: compact tensor printing, ffmpeg as the
# animation writer and a reversed grayscale colormap for images.
torch.set_printoptions(sci_mode=False, linewidth=140, precision=2)
plt.rcParams.update({
    'animation.writer': 'ffmpeg',
    'image.cmap': 'gray_r',
})

# Disable log records at WARNING level and below.
logging.disable(logging.WARNING)

# Seed the random number generators (call order kept as in the original).
torch.manual_seed(1)
set_seed(42)

### <font face='monospace'><b>Loading the dataset and preprocessing it</b>

In [None]:
# Constants
DATASET_NAME = "fashion_mnist"
IMAGE_KEY = 'image'
LABEL_KEY = 'label'
BATCH_SIZE = 128
EPOCHS = 5
LR = 4e-3

dataset = load_dataset(DATASET_NAME)


@inplace
def transformi(batch):
    """
    Convert every image in the batch to a tensor, resize it to 32x32 and
    rescale it with `* 2 - 1`; the result replaces `batch[IMAGE_KEY]`.
    """
    converted = []
    for image in batch[IMAGE_KEY]:
        resized = TF.resize(TF.to_tensor(image), (32, 32))
        converted.append(resized * 2 - 1)
    batch[IMAGE_KEY] = converted


# Apply the transform lazily and wrap the dataset splits in dataloaders.
transformed_dataset = dataset.with_transform(transformi)
data_loaders = DataLoaders.from_dd(transformed_dataset, BATCH_SIZE, num_workers=4)
print("DataLoaders for Fashion MNIST dataset created successfully.")

In [None]:
# Pull one batch from the training dataloader and display its first 16 images.
# `xb` is kept as a global because the FID/KID cells further down reuse it.
dt = data_loaders.train
xb,yb = next(iter(dt))
show_images(xb[:16], imsize=0.5)

<font face='monospace'>

### <b>Training - easy with a callback!</b>
DDPM is trained in a few simple steps:
1. Randomly select a timestep t for each image from the iterative noising process (0, ..., T).
2. Add Gaussian noise ε ~ N(0, 1) to the original image, scaled according to the selected timestep (0 <= t <= T). For increasing timesteps, the variance of the added noise increases.
3. Pass this noisy image and the timestep to our model.
4. Train the model with an MSE loss between the model output and the noise ε that was added to the image at that timestep.
<br>

We will implement this in a callback. After training, we need to sample from this model. This is an iterative denoising process starting from pure noise.

In [None]:
# Training callback that hands preparation of the training objects and the
# backward pass over to a Hugging Face Accelerate `Accelerator`
# (mixed precision is configured through `mixed_precision`).
class AccelerateCB(TrainCB):
    """Train callback that routes `prepare` and `backward` through an `Accelerator`."""
    order = DeviceCB.order+10

    def __init__(self, n_inp=1, mixed_precision="fp16"):
        super().__init__(n_inp=n_inp)
        self.acc = Accelerator(mixed_precision=mixed_precision)

    def before_fit(self, learn):
        """
        Pass the model, optimizer and both dataloaders through
        `Accelerator.prepare` and store the prepared objects on the learner.
        """
        prepared = self.acc.prepare(
            learn.model, learn.opt, learn.dls.train, learn.dls.valid)
        learn.model, learn.opt, learn.dls.train, learn.dls.valid = prepared

    def backward(self, learn):
        """
        Run backpropagation of `learn.loss` through the accelerator.
        """
        accelerator = self.acc
        accelerator.backward(learn.loss)

In [None]:
def noisify(x0, alpha_bar):
    """
    Add noise to the input images based on the given alpha_bar schedule.

    A random timestep is drawn independently for every image, and the image is
    mixed with Gaussian noise as
    `sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * epsilon`.

    Args:
        x0 (tensor): The original clean images, batch dimension first.
        alpha_bar (tensor): 1-D cumulative product of alphas; its length
            defines the number of diffusion steps to draw timesteps from.

    Returns:
        tuple: `((xt, t), epsilon)` where `xt` are the noised images, `t` the
        sampled timesteps (on the same device as `x0`) and `epsilon` the noise
        that was added (the regression target).
    """
    device = x0.device
    n = len(x0)  # returns the first dimension [b, c, h, w] -> b
    # Use the schedule length rather than a hard-coded 1000 so that any
    # number of diffusion steps works without indexing out of range.
    n_steps = len(alpha_bar)
    # Draw the timesteps on the schedule's own device: indexing a CPU
    # schedule with indices that live on an accelerator is not allowed.
    t = torch.randint(0, n_steps, (n,), dtype=torch.long, device=alpha_bar.device)
    epsilon = torch.randn(x0.shape, device=device)
    # One schedule value per image, shaped to broadcast over the remaining dims.
    broadcast_shape = (n,) + (1,) * (x0.dim() - 1)
    alpha_bar_t = alpha_bar[t].reshape(broadcast_shape).to(device)
    xt = alpha_bar_t.sqrt() * x0 + (1 - alpha_bar_t).sqrt() * epsilon
    return (xt, t.to(device)), epsilon

In [None]:
# Callback that turns each training batch into a DDPM noise-prediction task
class DDPMCB(Callback):
    """
    Callback that noises every batch and exposes a sampling helper.

    Attributes:
        order (int): Execution order; one step after `DeviceCB`.
        n_steps (int): Number of diffusion steps.
        beta_min (float): First value of the linear beta schedule.
        beta_max (float): Last value of the linear beta schedule.
        beta (tensor): `n_steps` linearly spaced values from beta_min to beta_max.
        alpha (tensor): `1 - beta`.
        alpha_bar (tensor): Cumulative product of `alpha`.
        sigma (tensor): Square root of `beta`.
    """
    order = DeviceCB.order + 1

    def __init__(self, n_steps, beta_min, beta_max):
        super().__init__()
        self.n_steps, self.beta_min, self.beta_max = n_steps, beta_min, beta_max
        betas = torch.linspace(beta_min, beta_max, n_steps)
        alphas = 1.0 - betas
        self.beta = betas
        self.alpha = alphas
        self.alpha_bar = alphas.cumprod(dim=0)
        self.sigma = torch.sqrt(betas)

    def before_batch(self, learn):
        """
        Replace the batch with the output of `noisify` applied to its first element.
        """
        clean_images = learn.batch[0]
        learn.batch = noisify(clean_images, self.alpha_bar)

    def sample(self, model, size):
        """
        Call the module-level `sample` with this callback's schedule tensors.
        """
        schedule = (self.alpha, self.alpha_bar, self.sigma)
        return sample(model, size, *schedule, self.n_steps)


<font face='monospace'>

### <b>Sampling</b>
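The `sample` function below differs from the conventional DDPM sampler: it still performs every denoising update, but it calls the model only at a subset of the timesteps and reuses the most recent noise prediction in between. This makes sampling faster and does not affect the generated output much.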

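   - The `sample_at` condition `(t + 101) % ((t + 101) // 100) == 0` selects every timestep for `t < 99`, every second timestep for `99 <= t < 199`, and so on, up to roughly every tenth timestep near `t = 999`. The model is evaluated only at these timesteps, which allows the sampler to skip the expensive forward pass for most timesteps.
   - Instead of storing the result of every timestep, the function only stores the intermediate image `x_t` at the `sample_at` timesteps in the `preds` list.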

In [None]:
@torch.no_grad()
def sample(model, sz, alpha, alphabar, sigma, n_steps=1000):
    """
    Run the reverse diffusion loop, calling the model only at selected steps.

    The model is evaluated at the timesteps in `sample_at` (and on the very
    first iteration); in between, the most recent noise prediction is reused.

    Args:
        model (torch.nn.Module): Called as `model((x_t, t_batch))` to predict noise.
        sz (tuple): Shape of the tensor to generate, batch dimension first.
        alpha (torch.Tensor): Per-step alpha values.
        alphabar (torch.Tensor): Cumulative product of the alpha values.
        sigma (torch.Tensor): Per-step noise scale.
        n_steps (int): Number of diffusion steps.

    Returns:
        list: Float CPU copies of `x_t`, one per `sample_at` timestep, in
        order of decreasing timestep.
    """
    device = next(model.parameters()).device
    x_t = torch.randn(sz, device=device)

    # Timesteps at which the model is re-evaluated and x_t is recorded.
    sample_at = set()
    for step in range(n_steps):
        shifted = step + 101
        if shifted % (shifted // 100) == 0:
            sample_at.add(step)

    preds = []
    noise = None
    for t in range(n_steps - 1, -1, -1):
        t_batch = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)
        if t > 0:
            z = torch.randn(x_t.shape, device=device)
            alpha_t1 = alphabar[t-1]
        else:
            # Final step: no fresh noise, and the "previous" alphabar is 1.
            z = torch.zeros(x_t.shape, device=device)
            alpha_t1 = torch.tensor(1.0).to(device)

        beta_bar_t = 1 - alphabar[t]
        beta_bar_t1 = 1 - alpha_t1
        is_checkpoint = t in sample_at

        # Refresh the noise prediction only at checkpoints (or when none exists yet)
        if is_checkpoint or noise is None:
            noise = model((x_t, t_batch))

        # Clean-image estimate implied by the current noise prediction
        x_0_hat = ((x_t - beta_bar_t.sqrt() * noise) / alphabar[t].sqrt()).clamp(-1, 1)

        # Weights of x_0_hat and x_t in the next, less noisy, x_t
        x0_coeff = alpha_t1.sqrt() * (1 - alpha[t]) / beta_bar_t
        xt_coeff = alpha[t].sqrt() * beta_bar_t1 / beta_bar_t

        x_t = x_0_hat * x0_coeff + x_t * xt_coeff + sigma[t] * z

        # Record the updated x_t at checkpoints only
        if is_checkpoint:
            preds.append(x_t.float().cpu())

    return preds

<font face='monospace'>**Let's use the predefined UNet model from the diffusers library to predict the noise.**

In [None]:
# Initialize DDPM model

def init_ddpm(model):
    """
    Initialise selected convolution weights of a UNet in place.

    - The `conv2` weight of every resnet in the down and up blocks is zeroed.
    - The `conv` weight of every downsampler gets an orthogonal init.
    - The `conv_out` weight of the model is zeroed.

    Args:
        model: Object exposing `down_blocks`, `up_blocks` (each block with
            `resnets`, down blocks also with `downsamplers`) and `conv_out`.
    """
    for block in model.down_blocks:
        for resnet in block.resnets:
            resnet.conv2.weight.data.zero_()
        # Initialise the downsamplers once per block, not once per resnet;
        # a block without downsamplers (None or empty) is skipped.
        for downsampler in (block.downsamplers or ()):
            init.orthogonal_(downsampler.conv.weight)
    for block in model.up_blocks:
        for resnet in block.resnets:
            resnet.conv2.weight.data.zero_()
    model.conv_out.weight.data.zero_()

In [None]:
class UNet(UNet2DModel):
    """`UNet2DModel` whose forward takes one tuple of inputs and returns the `.sample` tensor."""

    def forward(self, x):
        # `x` is unpacked into the positional arguments of the parent forward.
        output = super().forward(*x)
        return output.sample

In [None]:
# Total number of scheduler steps: one per training batch across all epochs.
tmax = len(data_loaders.train) * EPOCHS
opt_func = partial(optim.Adam, eps=1e-5)
sched = partial(lr_scheduler.OneCycleLR, total_steps=tmax, max_lr=LR)
ddpm_cb = DDPMCB(n_steps=1000, beta_min=0.0001, beta_max=0.02)

model = UNet(
    in_channels=1,
    out_channels=1,
    block_out_channels=(32, 64, 128, 128),
    norm_num_groups=8,
)
init_ddpm(model)

cbs = [
    ddpm_cb,
    DeviceCB(),
    ProgressCB(plot=True),
    MetricsCB(),
    BatchSchedCB(sched),
    AccelerateCB(),
]
learn = Learner(model, data_loaders, nn.MSELoss(), lr=LR, cbs=cbs, opt_func=opt_func)

In [None]:
# Train the learner for EPOCHS epochs with the callbacks configured above.
learn.fit(EPOCHS)

---
<font face='monospace'>Let's sample from our model, post-process the generated images, and display them to see how they look.

In [None]:
# Standalone copy of the linear noise schedule, used by the sampler below.
beta = torch.linspace(0.0001, 0.02, 1000)
alpha = 1.0 - beta
alphabar = torch.cumprod(alpha, dim=0)
sigma = torch.sqrt(beta)

In [None]:
# Generate one image; `samples` holds the intermediate x_t snapshots.
samples = sample(learn.model, (1, 1, 32, 32), alpha, alphabar, sigma, 1000)
# The training images were scaled with `* 2 - 1`, so the model works in
# [-1, 1]. Invert that scaling to bring the final sample back to [0, 1]
# instead of applying the forward scaling a second time.
s = (samples[-1] + 1) / 2
show_images(s[:16], figsize=(4,4), imsize=1.5)

In [None]:
# Save the trained model (the whole module is pickled, not just its weights)
model_path = Path('models')
model_path.mkdir(exist_ok=True)
checkpoint_file = model_path / 'fashion_ddpm.pkl'
torch.save(learn.model, checkpoint_file)

In [None]:
# Load the trained model
# The checkpoint is a fully pickled module, so it must be loaded with
# `weights_only=False`; recent PyTorch versions default to weights-only
# loading and would refuse it.
# NOTE: unpickling can execute arbitrary code - only load files you created.
learn.model = torch.load(model_path/'fashion_ddpm.pkl', weights_only=False)

In [None]:
%matplotlib auto

# Let's visualize the sampling process
def getImageFromList(x):
    """Return the first image of the x-th stored sampling snapshot as a 2-D tensor."""
    # Index the list of snapshots returned by `sample`, not the final batch
    # `s`: iterating over `s` only ever produced a single frame.
    return samples[x][0, 0]

fig = plt.figure(figsize=(3, 3))
# One animation frame per stored snapshot, in sampling order.
ims = [[plt.imshow(getImageFromList(i), animated=True)] for i in range(len(samples))]

ani = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
plt.close()

# Show the animation
HTML(ani.to_html5_video())

<font face='monospace'>
    
<b>NOTE:</b> Don't forget to install ffmpeg on your operating system. If you are using conda, then try `conda install conda-forge::ffmpeg` in your terminal.

<font face='monospace'>

### **FID & KID**

`FID` - It calculates the distance between feature vectors calculated for real and generated images.

`KID` - It measures the squared Maximum Mean Discrepancy (MMD) between the Inception representations of the real and generated samples. MMD is a measure of the distance between two probability distributions. It's calculated using a kernel function, which is a measure of similarity between data points.

In [None]:
# Linear noise schedule for the evaluation sampler (same values as in training).
beta = torch.linspace(0.0001, 0.02, 1000)
alpha = 1.0 - beta
alphabar = torch.cumprod(alpha, dim=0)
sigma = torch.sqrt(beta)

In [None]:
# Load the trained diffusion model
# The file holds a fully pickled module, so `weights_only=False` is required
# on recent PyTorch versions (which default to weights-only loading).
# NOTE: unpickling can execute arbitrary code - only load files you created.
smodel = torch.load('models/fashion_ddpm.pkl', weights_only=False)

# Function to sample from the diffusion model
@torch.no_grad()
def sample(model, size, alpha, alphabar, sigma, n_steps):
    """
    Run the full DDPM reverse process, calling the model at every timestep.

    Args:
        model (torch.nn.Module): Called as `model((x_t, t_batch))` to predict noise.
        size (tuple): Shape of the tensor to generate, batch dimension first.
        alpha (torch.Tensor): Per-step alpha values.
        alphabar (torch.Tensor): Cumulative product of the alpha values.
        sigma (torch.Tensor): Per-step noise scale.
        n_steps (int): Number of diffusion steps.

    Returns:
        list: CPU copies of the clean-image estimate `x0_hat` from every
        step, ordered from timestep `n_steps - 1` down to 0.
    """
    device = next(model.parameters()).device
    x_t = torch.randn(size, device=device)
    preds = []
    for t in reversed(range(n_steps)):
        t_batch = torch.full((x_t.shape[0],), t, device=device, dtype=torch.long)
        # Create the step noise directly on the target device instead of
        # allocating it on the CPU and copying it over on every step.
        z = torch.randn(x_t.shape, device=device) if t > 0 else torch.zeros(x_t.shape, device=device)
        # At the last step the "previous" cumulative alpha is 1 (no noise left).
        alphabar_t1 = alphabar[t-1] if t > 0 else torch.tensor(1.0, device=device)
        bbar_t = 1 - alphabar[t]
        bbar_t1 = 1 - alphabar_t1
        noise_pred = model((x_t, t_batch))
        # Clean-image estimate implied by the predicted noise
        x0_hat = (x_t - bbar_t.sqrt() * noise_pred) / alphabar[t].sqrt()
        # Weights of x0_hat and x_t in the next, less noisy, x_t
        x0_coeff = alphabar_t1.sqrt() * (1 - alpha[t]) / bbar_t
        xt_coeff = alpha[t].sqrt() * bbar_t1 / bbar_t
        x_t = x0_hat * x0_coeff + x_t * xt_coeff + sigma[t] * z
        preds.append(x0_hat.cpu())
    return preds

In [None]:
# Generate samples from the diffusion model
samples = sample(smodel, (128, 1, 32, 32), alpha, alphabar, sigma, 1000)
# The diffusion model was trained on images scaled with `* 2 - 1`, i.e. in
# [-1, 1], which is the same range `transformi2` gives the evaluation data.
# Use the final estimate as-is; rescaling it again would put the generated
# images in a different range from the reference images.
s = samples[-1]
show_images(s[:16], imsize=1.5)

In [None]:
# `F.pad` is used below; import it here because no import visible in this
# notebook binds the name `F` explicitly.
import torch.nn.functional as F


@inplace
def transformi2(batch):
    """Pad every image by 2 pixels on each side and rescale it with `* 2 - 1`, in place."""
    batch['image'] = [F.pad(TF.to_tensor(img), (2, 2, 2, 2)) * 2 - 1 for img in batch['image']]

tds = dataset.with_transform(transformi2)
dls = DataLoaders.from_dd(tds, BATCH_SIZE, num_workers=4)

# Load the pre-trained CNN model for evaluation
# The file holds a fully pickled module, so `weights_only=False` is required
# on recent PyTorch versions (which default to weights-only loading).
# NOTE: unpickling can execute arbitrary code - only load files you trust.
cmodel = torch.load('models/inference.pkl', weights_only=False)
del cmodel[8]  # these are linear and probability layers which we don't need
del cmodel[7]

In [None]:
# Instantiate the ImageEval helper that is used below to compute FID and KID
# from the truncated classifier `cmodel` and the reference dataloaders `dls`.
# NOTE(review): per the original note it summarises the features of the last
# remaining layer by their mean and covariance - confirm against ImageEval.
ie = ImageEval(cmodel, dls, cbs=[DeviceCB()])

In [None]:
# FID of the generated batch against the reference data held by `ie`.
fid_generated = ie.fid(s)
# `xb` was already scaled to [-1, 1] by `transformi`, the same range as the
# reference data, so it is passed as-is; doubling it would skew the baseline.
fid_original = ie.fid(xb)

In [None]:
# KID of the generated batch against the reference data held by `ie`.
kid_generated = ie.kid(s)
# `xb` was already scaled to [-1, 1] by `transformi`, the same range as the
# reference data, so it is passed as-is; doubling it would skew the baseline.
kid_original = ie.kid(xb)

In [None]:
# Report the metrics of the generated images first, then - after a blank
# line - the same metrics for a batch of real images as a baseline.
print(f"FID of generated images: {fid_generated}")
print(f"KID of generated images: {kid_generated}")
print()
print(f"FID of original images: {fid_original}")
print(f"KID of original images: {kid_original}")

In [None]:
# Force a garbage-collection pass to release unreachable objects.
import gc
gc.collect()

In [None]:
%reset -f