<font face='monospace'>

## <b>Denoising Diffusion Implicit Models - DDIM</b>

In this notebook we explore different variants of DDIM.

DDIM sampling adds no fresh noise at each step (σ = 0), which makes it deterministic, while DDPM sampling adds noise at every step (σ = 1). Any σ between 0 and 1 interpolates between DDIM and DDPM.
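The DDIM paper writes this noise scale as η. Below is a minimal sketch of one DDIM update in the ε-prediction form. It is not code from this notebook, and `ddim_step`, `abar_t` (ᾱ at the current step) and `abar_prev` (ᾱ at the next, less noisy step) are hypothetical names. With `eta=0` the step is deterministic. With `eta=1` its noise variance matches the DDPM posterior.

```python
import torch

def ddim_step(x_t, eps, abar_t, abar_prev, eta=0.0):
    """One DDIM update from step t to the next, less noisy step (hypothetical helper).

    eps: the model's noise prediction; abar_t, abar_prev: cumulative alpha products (tensors).
    eta=0 -> deterministic DDIM, eta=1 -> DDPM-like stochastic step.
    """
    # Clean image implied by the noise prediction
    x0_pred = (x_t - (1 - abar_t).sqrt() * eps) / abar_t.sqrt()
    # Per-step noise scale sigma_t(eta) from the DDIM paper
    sigma = eta * ((1 - abar_prev) / (1 - abar_t) * (1 - abar_t / abar_prev)).sqrt()
    # Deterministic part that points back towards x_t
    dir_xt = (1 - abar_prev - sigma ** 2).sqrt() * eps
    noise = torch.randn_like(x_t) if eta > 0 else torch.zeros_like(x_t)
    return abar_prev.sqrt() * x0_pred + dir_xt + sigma * noise

# Usage on a synthetic example where eps is the true noise
torch.manual_seed(0)
x0 = torch.randn(4, 1, 8, 8)
eps = torch.randn_like(x0)
abar_t, abar_prev = torch.tensor(0.5), torch.tensor(0.8)
x_t = abar_t.sqrt() * x0 + (1 - abar_t).sqrt() * eps
x_prev = ddim_step(x_t, eps, abar_t, abar_prev, eta=0.0)
```

With a perfect noise prediction and `eta=0`, the step lands exactly on the noised image at the lower noise level, √ᾱ_prev·x₀ + √(1−ᾱ_prev)·ε.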

How could we improve the DDPM/DDIM implementation? Two ideas:
- remove the concept of an integral number of steps, making the process continuous;
- predict the amount of noise in an image without passing the time step as input, and modify the DDIM step to use the predicted $\bar\alpha$ (alpha bar) for each image.
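One way to connect the two ideas: if noise is added as $x = x_0 + \sigma\epsilon$ (the form `noisify` uses later in this notebook), then dividing by $\sqrt{1+\sigma^2}$ gives exactly the DDPM form $\sqrt{\bar\alpha}\,x_0 + \sqrt{1-\bar\alpha}\,\epsilon$ with $\bar\alpha = 1/(1+\sigma^2)$. A predicted noise level can therefore stand in for a predicted $\bar\alpha$ in a DDIM step. Here is a minimal sketch of the mapping. `sigma_to_abar` and `abar_to_sigma` are hypothetical helper names.

```python
import torch

def sigma_to_abar(sigma):
    # x = x0 + sigma*eps, rescaled by 1/sqrt(1 + sigma^2), equals sqrt(abar)*x0 + sqrt(1 - abar)*eps
    return 1 / (1 + sigma ** 2)

def abar_to_sigma(abar):
    # Inverse mapping: sigma = sqrt((1 - abar) / abar)
    return ((1 - abar) / abar).sqrt()

torch.manual_seed(0)
x0 = torch.randn(1000, dtype=torch.float64)
eps = torch.randn(1000, dtype=torch.float64)
sigma = torch.tensor(2.0, dtype=torch.float64)
abar = sigma_to_abar(sigma)                          # 1 / (1 + 4) = 0.2
ve_rescaled = (x0 + sigma * eps) / (1 + sigma ** 2).sqrt()
vp_form = abar.sqrt() * x0 + (1 - abar).sqrt() * eps
```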

<font face='monospace'>

1️⃣: Let's implement these ideas in a model to obtain a variant of DDIM.

- Implicit use of noise levels: the code below represents the noise level as a continuous value σ (sigma) and uses it to add and remove noise. This is analogous to using continuous time steps, without handling time steps explicitly.

- Noise-level conditioning: the model receives the noisy images together with their σ and predicts a preconditioned target (a scaled mix of the clean image and the noise, following Karras et al.). Working with the noise scale directly instead of an integer time step lets the model handle any continuous noise level.

- Because σ is continuous, the number of sampling steps and the noise schedule can be chosen freely at sampling time, which makes sampling more flexible and allows using fewer steps.



In [None]:
%pip install -qU fastai fastcore datasets torcheval diffusers

In [None]:
import os
import torch
import logging
import fastcore.all as fc
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from glob import glob
from torch import optim
from pathlib import Path
from torch.nn import init
from torch import nn,tensor
from functools import partial
from datasets import load_dataset
from diffusers import UNet2DModel
from fastcore.foundation import L
from torch.optim import lr_scheduler
from fastprogress import progress_bar
from torch.utils.data import DataLoader,default_collate

from diffusion_ai import *

In [None]:
# Disable logging warnings
logging.disable(logging.WARNING)

# Set printing options and seed for reproducibility
set_seed(42)
torch.manual_seed(1)
torch.set_printoptions(precision=5, linewidth=140, sci_mode=False)
plt.rcParams['image.cmap'] = 'gray_r'
plt.rcParams['figure.dpi'] = 70

In [None]:
# Load Fashion MNIST dataset
n_steps = 1000
batch_size = 48
sz = (48,1,32,32)
name = "fashion_mnist"
dataset = load_dataset(name)

@inplace
def transformi(batch): 
    batch['image'] = [F.pad(TF.to_tensor(o), (2,2,2,2))*2-1 for o in batch['image']]

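transformed_dataset = dataset.with_transform(transformi)
data_loaders = DataLoaders.from_dd(transformed_dataset, batch_size, num_workers=4)

dl = data_loaders.train
xb,yb = b = next(iter(dl))

# Load a pretrained Fashion-MNIST classifier and remove layers 8 and 7 (in that order, so the
# indices don't shift), so that ImageEval can use the rest as a feature extractor for FID and KID.
cmodel = torch.load('models/inference.pkl', weights_only=False)  # full pickled model, not a state_dict
del cmodel[8]
del cmodel[7]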

ie = ImageEval(cmodel, data_loaders, cbs=[DeviceCB()])

<font face='monospace'>
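The standard deviation measures how spread out a set of data is. In the Karras et al. preconditioning used below, $\sigma_{data}$ (the standard deviation of the training pixels) sets the scaling coefficients `c_skip`, `c_out` and `c_in`. These keep the network's inputs and training targets at roughly unit variance for every noise level, which makes training better behaved and helps reach a lower loss.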
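The notebook hard-codes $\sigma_{data} = 0.66$ in the next cell (the commented-out `xb.std()` would only use one batch). Here is a minimal sketch of computing it over the whole training set in one streaming pass, by accumulating the sum and the sum of squares of all pixels. `dataset_std` is a hypothetical helper that takes any iterable of image tensors. With the loaders in this notebook, something like `dataset_std(xb for xb, _ in data_loaders.train)` should work, assuming batches are `(images, labels)` tuples as in the `xb,yb = b = next(iter(dl))` line.

```python
import torch

def dataset_std(batches):
    """Population std over every pixel of every batch, in one streaming pass (hypothetical helper)."""
    n, total, total_sq = 0, 0.0, 0.0
    for xb in batches:
        xb = xb.double()                  # float64 to limit round-off in the running sums
        n += xb.numel()
        total += xb.sum().item()
        total_sq += (xb ** 2).sum().item()
    mean = total / n
    return (total_sq / n - mean ** 2) ** 0.5

# Usage with synthetic batches standing in for the training images
torch.manual_seed(0)
chunks = [torch.randn(8, 1, 4, 4) * 0.7 + 0.1 for _ in range(5)]
print(dataset_std(chunks))
```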

In [None]:
# data_std = xb.std()
data_std = torch.tensor(0.66)  # standard deviation of our entire training dataset

In [None]:
# Calculate scaling coefficients for noise
def calculate_scalings(sigma):
    total_variance = sigma ** 2 + data_std ** 2
    c_skip = data_std ** 2 / total_variance
    c_out = sigma * data_std / total_variance.sqrt()
    c_in = 1 / total_variance.sqrt()
    return c_skip, c_out, c_in
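
To see why these coefficients are chosen, here is a small synthetic check. It is a sketch that uses Gaussian data with standard deviation $\sigma_{data}$ as a stand-in for real images. For any σ, the network input `c_in * x` and the training target `(y - c_skip * x) / c_out` both come out with roughly unit standard deviation.

```python
import torch

data_std = torch.tensor(0.66)

def calculate_scalings(sigma):
    # Same formulas as in the notebook cell above
    total_variance = sigma ** 2 + data_std ** 2
    c_skip = data_std ** 2 / total_variance
    c_out = sigma * data_std / total_variance.sqrt()
    c_in = 1 / total_variance.sqrt()
    return c_skip, c_out, c_in

torch.manual_seed(0)
y = torch.randn(1_000_000) * data_std          # Gaussian stand-in for clean images
results = []
for s in [0.05, 0.66, 5.0, 80.0]:
    sigma = torch.tensor(s)
    c_skip, c_out, c_in = calculate_scalings(sigma)
    x = y + torch.randn_like(y) * sigma          # noised data, as in noisify
    target = (y - c_skip * x) / c_out            # training target, as in noisify
    results.append((s, (c_in * x).std().item(), target.std().item()))
    print(f"sigma={s:5}: std(c_in * x)={results[-1][1]:.3f}, std(target)={results[-1][2]:.3f}")
```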

In [None]:
c_skip, c_out, c_in = calculate_scalings(data_std)  # example: scalings at σ = σ_data, where c_skip = 0.5

In [None]:
# Function to add noise to images
def noisify(images):
    device = images.device
    # ln σ ~ N(-1.2, 1.2²), i.e. σ is log-normal with median e^-1.2 ≈ 0.30
    # (P_mean = -1.2, P_std = 1.2, as in Karras et al.).
    sigma = (torch.randn([len(images)]) * 1.2 - 1.2).exp().to(images).reshape(-1, 1, 1, 1)
    noise = torch.randn_like(images, device=device)
    c_skip, c_out, c_in = calculate_scalings(sigma)
    noised_images = images + noise * sigma
    # Target the network should predict, so that x0 = c_skip * noised_images + c_out * target
    targets = (images - c_skip * noised_images) / c_out
    # Scale inputs by c_in so the network sees roughly unit-variance inputs
    return (noised_images * c_in, sigma.squeeze()), targets
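
The σ draw in `noisify` is log-normal. A quick sketch that samples it many times shows what the parameters mean in practice. The median is $e^{-1.2} \approx 0.30$, and about 95% of the draws fall between $e^{-1.2 - 1.96 \cdot 1.2} \approx 0.03$ and $e^{-1.2 + 1.96 \cdot 1.2} \approx 3.2$.

```python
import math
import torch

torch.manual_seed(0)
sigma = (torch.randn(1_000_000) * 1.2 - 1.2).exp()   # same draw as in noisify

median = sigma.median().item()
lo, hi = torch.quantile(sigma, torch.tensor([0.025, 0.975])).tolist()
print(f"median ≈ {median:.3f} (e^-1.2 = {math.exp(-1.2):.3f}), central 95% ≈ [{lo:.3f}, {hi:.3f}]")
```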

In [None]:
# Custom collate function for adding noise to input images.
def custom_collate(batch):
    return noisify(default_collate(batch)['image'])

# Create DataLoader with custom collate function
def create_dataloader(dataset):
    return DataLoader(dataset, batch_size=batch_size, collate_fn=custom_collate, num_workers=4)

In [None]:
dataloaders = DataLoaders(create_dataloader(transformed_dataset['train']), create_dataloader(transformed_dataset['test']))

In [None]:
dl = dataloaders.train
(noised_input,sig),target = b = next(iter(dl))

In [None]:
show_images(noised_input[:25], imsize=1.5, titles=fc.map_ex(sig[:25], '{:.02f}'))

<font face='monospace'>
The images above look noisy because each one had Gaussian noise of its own level σ (shown in the titles) added to it, and was then scaled by `c_in` as part of the preconditioning. Because we know σ for every image, we can also compute the exact training target, and the model learns from this. Note that we do not use the `label` column of the dataset, because we are not using `CLIP` or any other form of conditioning.

In [None]:
show_images(target[:25], imsize=1.5, titles=fc.map_ex(sig[:25], '{:.02f}'))

<font face='monospace'>
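Notice the inverse relationship: images with a small σ look nearly clean in `noised_input`, but their target looks like noise; images with a large σ look noisy in `noised_input`, but their target is close to the clean image.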
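The inverse relationship follows from the limits of the Karras scaling coefficients. Write the noised image as $x = y + \sigma\epsilon$, with clean image $y$ and unit Gaussian noise $\epsilon$. The target $(y - c_{skip}x)/c_{out}$ then behaves as follows:

$$
\begin{aligned}
\sigma \ll \sigma_{data}:&\quad c_{skip} \approx 1,\ c_{out} \approx \sigma \;\Rightarrow\; \text{target} = \frac{y - c_{skip}(y + \sigma\epsilon)}{c_{out}} \approx -\epsilon \\
\sigma \gg \sigma_{data}:&\quad c_{skip} \approx 0,\ c_{out} \approx \sigma_{data} \;\Rightarrow\; \text{target} \approx \frac{y}{\sigma_{data}}
\end{aligned}
$$

So for small σ the model is trained to predict (negated) noise, and for large σ a rescaled clean image.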

In [None]:
# Invert the training target to recover the clean images: x0 = c_out * target + c_skip * noised_images
def denoise_images(target, noised_images, c_skip, c_out):
    return target * c_out + noised_images * c_skip

In [None]:
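# Use each image's own σ for the scalings (the global c_skip/c_out/c_in computed earlier are for σ = data_std only)
c_skip_b, c_out_b, c_in_b = calculate_scalings(sig.reshape(-1, 1, 1, 1))
show_images(denoise_images(target, noised_input/c_in_b, c_skip_b, c_out_b)[:25], imsize=1.5, titles=fc.map_ex(sig[:25], '{:.02f}'))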

<font face='monospace'>
Inverting the target recovers the original, noise-free images.

In [None]:
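# Custom UNet: UNet2DModel.forward takes (sample, timestep), so unpack the (noised images, σ) tuple and return the output tensor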
class UNet(UNet2DModel):
    def forward(self, x):
        return super().forward(*x).sample

In [None]:
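# Initialize the UNet: zero the last conv of every ResNet block (so each block initially contributes
# little beyond its skip path) and the output conv, and use orthogonal init for the downsamplers.
def init_ddpm(model):
    for o in model.down_blocks:
        for p in o.resnets:
            p.conv2.weight.data.zero_()
        for p in fc.L(o.downsamplers):  # fc.L(None) is empty for blocks without a downsampler
            init.orthogonal_(p.conv.weight)
    for o in model.up_blocks:
        for p in o.resnets:
            p.conv2.weight.data.zero_()
    model.conv_out.weight.data.zero_()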

In [None]:
# Training parameters
learning_rate = 1e-2
EPOCHS = 25
opt_function = partial(optim.Adam, eps=1e-5)
total_steps = EPOCHS * len(dataloaders.train)  # one scheduler step per training batch
scheduler = partial(lr_scheduler.OneCycleLR, max_lr=learning_rate, total_steps=total_steps)
callbacks = [DeviceCB(), MixedPrecision(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(scheduler)]

# Create model, initialize it and create a Learner
model = UNet(in_channels=1, out_channels=1, block_out_channels=(32, 64, 128, 256), norm_num_groups=8)
init_ddpm(model)
learn = Learner(model, dataloaders, nn.MSELoss(), lr=learning_rate, cbs=callbacks, opt_func=opt_function)

In [None]:
# Train the model
learn.fit(EPOCHS)

In [None]:
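# This is the denoising model.
torch.save(learn.model, 'models/fashion_karras.pkl')
# weights_only=False is needed to unpickle a full model object in recent PyTorch versions
model = learn.model = torch.load('models/fashion_karras.pkl', weights_only=False)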

In [None]:
# Denoise the first batch with the trained model, moving the inputs to the model's device
device = next(learn.model.parameters()).device
with torch.no_grad():
    x_in, sig_in = noised_input.to(device), sig.to(device)
    sigma = sig_in.reshape(-1, 1, 1, 1)  # per-image sigma from the first batch
    c_skip, c_out, c_in = calculate_scalings(sigma)
    target_pred = learn.model((x_in, sig_in))
    x0_pred = denoise_images(target_pred, x_in / c_in, c_skip, c_out).cpu()

In [None]:
show_images(noised_input[:25], imsize=1.5, titles=fc.map_ex(sig[:25], '{:.02f}'))

In [None]:
show_images(x0_pred[:25].clamp(-1,1), imsize=1.5, titles=fc.map_ex(sig[:25], '{:.02f}'))

<font face='monospace'>
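The model predicts the preconditioned target, and `denoise_images` turns that prediction into an estimate of the clean image in a single step; in effect, the model predicts how much noise to remove.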

---

<font face='monospace'>

2️⃣: The `σ` formulation used in the process above was:
```python
σ = torch.randn(batch_size) * 1.2 - 1.2  # ln σ ~ N(-1.2, 1.2²), one value per image
σ = σ.exp()
```
---

Now let's try a different formulation for the noise levels used during sampling: the schedule proposed by Karras et al.

$$
\sigma(n, \sigma_{min}, \sigma_{max}, \rho) = \left( \sigma_{max}^{1/\rho} + \text{ramp}(0, 1, n) \times \left( \sigma_{min}^{1/\rho} - \sigma_{max}^{1/\rho} \right) \right)^\rho
$$

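Where:
- $n$ is the number of sigma values to generate.
- $\sigma_{max}$ and $\sigma_{min}$ are the first and last noise levels of the schedule.
- $\rho$ controls how the steps are distributed between $\sigma_{max}$ and $\sigma_{min}$: the interpolation is linear in $\sigma^{1/\rho}$, so a larger $\rho$ places more steps near $\sigma_{min}$ at the cost of larger steps near $\sigma_{max}$, and $\rho = 1$ gives a linear schedule.
- $\text{ramp}(0, 1, n)$ is a sequence of $n$ values linearly spaced from 0 to 1.

A zero is appended to the end of the sequence, so the final sampling step goes all the way to a noise-free image ($\sigma = 0$).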
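To see the effect of $\rho$, here is a small sketch that builds 100-step schedules with $\rho = 1, 3, 7$ and counts how many noise levels fall below 1. `karras_sigmas` is a hypothetical stand-alone copy of the formula above, without the trailing zero.

```python
import torch

def karras_sigmas(n, sigma_min=0.01, sigma_max=80., rho=7.):
    # The schedule formula above (without the trailing zero)
    ramp = torch.linspace(0, 1, n)
    min_inv_rho, max_inv_rho = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

below_one = {}
for rho in [1., 3., 7.]:
    s = karras_sigmas(100, rho=rho)
    below_one[rho] = int((s < 1).sum())
    print(f"rho={rho}: first step {s[0].item():.2f} -> {s[1].item():.2f}, {below_one[rho]} of 100 sigmas below 1")
```

A linear schedule ($\rho = 1$) spends almost no steps at low noise, where fine detail is resolved, while $\rho = 7$ spends a large share of them there.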



In [None]:
# Generate sigmas using Karras noise scheduling.
def sigmas_karras(n, sigma_min=0.01, sigma_max=80., rho=7., device='cpu'):
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min**(1/rho)
    max_inv_rho = sigma_max**(1/rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho-max_inv_rho))**rho
    return torch.cat([sigmas, tensor([0.])]).to(device)

In [None]:
# Redefine denoise_images for sampling: scale x by c_in, run the model, and add the skip connection
def denoise_images(model, x, sig):
    c_skip,c_out,c_in = calculate_scalings(sig)
    return model((x*c_in, sig))*c_out + x*c_skip

<font face='monospace'>

### <b>Euler Sampler</b>
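The main idea is to treat sampling as solving an ordinary differential equation (ODE) that moves the image from $\sigma_{max}$ down to $\sigma = 0$, with the learned denoiser supplying the direction of each step.

In the context of diffusion models, the Euler method updates the noisy image $x$ with a first-order step along the slope $(x - \text{denoised})/\sigma$, which is computed from the denoising model's prediction.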

Equation:

$x_{i+1} = x_i + \left(\frac{x_i - \text{denoised}}{\sigma_i}\right) (\sigma_{i+1} - \sigma_i)$

Where:
- $x_i$ is the current noisy image.
- $\text{denoised}$ is the model's prediction of the denoised image.
- $\sigma_i$ is the current noise level.
- $\sigma_{i+1}$ is the next noise level.
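To sanity-check the update rule without a trained network, consider a toy case (an assumption for illustration only). If the whole dataset is a single point $y_0$, the ideal denoiser always returns $y_0$, and each Euler step gives $x_{i+1} - y_0 = (x_i - y_0)\,\sigma_{i+1}/\sigma_i$. The sampler therefore lands exactly on $y_0$ when $\sigma$ reaches 0. `toy_denoiser` and `euler_step` are hypothetical names; `euler_step` uses the same update as `euler_sample` in the next cell but takes the two noise levels directly.

```python
import torch

y0 = torch.tensor([0.5, -0.25], dtype=torch.float64)   # the single "training image"

def toy_denoiser(x, sig):
    # Ideal denoiser when the data distribution is the single point y0
    return y0.expand_as(x)

def euler_step(x, sig, sig2, denoiser):
    # x_{i+1} = x_i + (x_i - denoised) / sigma_i * (sigma_{i+1} - sigma_i)
    return x + (x - denoiser(x, sig)) / sig * (sig2 - sig)

# Any decreasing schedule ending at 0 works here; a linear one is used for simplicity
sigs = torch.cat([torch.linspace(80., 0.01, 50, dtype=torch.float64), torch.zeros(1, dtype=torch.float64)])

torch.manual_seed(0)
x = torch.randn(2, dtype=torch.float64) * sigs[0]      # start from pure noise at sigma_max
for i in range(len(sigs) - 1):
    x = euler_step(x, sigs[i], sigs[i + 1], toy_denoiser)
print(x)   # equals y0 up to floating-point round-off
```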



In [None]:
# Euler sampler for updating the noisy image.
@torch.no_grad()
def euler_sample(x, sigs, i, model):
    sig,sig2 = sigs[i],sigs[i+1]
    denoised = denoise_images(model, x, sig)
    return x + (x-denoised)/sig*(sig2-sig)

In [None]:
# Generate samples using the euler sampler with karras noising.
def sample(sampler, model, steps=100, sigma_max=80., **kwargs):
    preds = []
    x = torch.randn(sz).to(model.device)*sigma_max
    sigs = sigmas_karras(steps, device=model.device, sigma_max=sigma_max)
    for i in progress_bar(range(len(sigs)-1)):
        x = sampler(x, sigs, i, model, **kwargs)
        preds.append(x)
    return preds

In [None]:
preds = sample(euler_sample, model, steps=100)

In [None]:
s = preds[-1]
s.min(),s.max()

In [None]:
show_images(s[:25].clamp(-1,1), imsize=1.5)

In [None]:
# euler 100
ie.fid(s),ie.kid(s),s.shape

In [None]:
# FID of a batch of real images, for reference
ie.fid(xb)

In [None]:
import gc
gc.collect()

In [None]:
%reset -f

<font face='monospace'>

We can use many other samplers, such as `Heun`, `Euler Ancestral` and `Linear Multistep (LMS)`. They share the same overall loop (start from noise at $\sigma_{max}$ and step down the schedule) and differ mainly in the denoising/update step, so swapping one for another only requires changing the function passed to `sample`. That's all!
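As one example of a different update step, here is a sketch of a Heun (second-order) step in the style of Karras et al. It takes an Euler step, re-evaluates the slope at the new point, and averages the two slopes. The last step to σ = 0 falls back to Euler because the slope is undefined there.

The sketch is checked on a toy problem rather than the trained UNet. For zero-mean Gaussian data with standard deviation σ_data, the ideal denoiser is $D(x;\sigma) = \sigma_{data}^2/(\sigma^2+\sigma_{data}^2)\,x$, and the exact ODE solution is $x(0) = x_T\,\sigma_{data}/\sqrt{\sigma_{max}^2+\sigma_{data}^2}$. `heun_step`, `euler_step` and `gauss_denoiser` are hypothetical names. On the same schedule, Heun should land closer to the exact value than Euler, at the cost of roughly twice as many denoiser calls.

```python
import math
import torch

data_std = 0.66

def gauss_denoiser(x, sig):
    # Ideal denoiser E[y | x] for zero-mean Gaussian data with std data_std
    return data_std ** 2 / (sig ** 2 + data_std ** 2) * x

def euler_step(x, sig, sig2, denoiser):
    # Same first-order update as euler_sample
    return x + (x - denoiser(x, sig)) / sig * (sig2 - sig)

def heun_step(x, sig, sig2, denoiser):
    d = (x - denoiser(x, sig)) / sig           # slope at the current noise level
    x2 = x + d * (sig2 - sig)                  # Euler predictor
    if sig2 == 0:                              # slope is undefined at sigma = 0: keep the Euler step
        return x2
    d2 = (x2 - denoiser(x2, sig2)) / sig2      # slope at the predicted point
    return x + (d + d2) / 2 * (sig2 - sig)     # average the two slopes

sig_max = 80.
# 20 geometrically spaced noise levels from 80 down to 0.01, then a final 0
sigs = torch.cat([torch.exp(torch.linspace(math.log(sig_max), math.log(0.01), 20, dtype=torch.float64)),
                  torch.zeros(1, dtype=torch.float64)])

torch.manual_seed(0)
x_T = torch.randn(1000, dtype=torch.float64) * math.sqrt(sig_max ** 2 + data_std ** 2)
exact = x_T * data_std / math.sqrt(sig_max ** 2 + data_std ** 2)   # exact ODE solution at sigma = 0

x_e, x_h = x_T.clone(), x_T.clone()
for i in range(len(sigs) - 1):
    x_e = euler_step(x_e, sigs[i], sigs[i + 1], gauss_denoiser)
    x_h = heun_step(x_h, sigs[i], sigs[i + 1], gauss_denoiser)

err_euler = (x_e - exact).abs().max().item()
err_heun = (x_h - exact).abs().max().item()
print(f"Euler max error: {err_euler:.4f}, Heun max error: {err_heun:.4f}")
```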