<font face='monospace'>

## <b>Denoising Diffusion Implicit Models - DDIM</b>

What we are implementing here is an unconditional model; we are not performing class conditioning in this notebook, which will be addressed in another notebook.

In [None]:
%pip install -qU fastai fastcore datasets torcheval diffusers

In [None]:
import os
import torch
import logging
import matplotlib as mpl
import fastcore.all as fc
import torch.nn.functional as F
import torchvision.transforms.functional as TF

from diffusion_ai import *
from torch.nn import init
from torch import nn,optim
from functools import partial
from diffusers import UNet2DModel
from fastcore.foundation import L
from types import SimpleNamespace
from datasets import load_dataset
from torch.optim import lr_scheduler
from fastprogress.fastprogress import progress_bar
from torch.utils.data import DataLoader,default_collate

In [None]:
# Seed the random number generators so runs are repeatable.
# NOTE(review): set_seed comes from the diffusion_ai star import and presumably seeds torch
# as well; the torch.manual_seed(1) call right after it would then override torch's seed
# -- confirm that two different seeds are intended.
set_seed(42)
torch.manual_seed(1)
# Silence log records at WARNING level and below.
logging.disable(logging.WARNING)
# Plot defaults: lower figure DPI and a reversed-gray colormap for images.
mpl.rcParams['figure.dpi'] = 70
mpl.rcParams['image.cmap'] = 'gray_r'
# Compact, non-scientific tensor printing.
torch.set_printoptions(precision=4, linewidth=140, sci_mode=False)

### <font face='monospace'><b>Loading the dataset and preprocessing it.</b>

In [None]:
# Names of the image and label columns in the dataset.
xl,yl = 'image','label'
name = "fashion_mnist"
# Load the dataset by name through datasets.load_dataset.
dsd = load_dataset(name)

In [None]:
# Set batch size (used for every DataLoader created in this notebook)
bs = 256

@inplace
def transformi(batch):
    """Convert every image in ``batch[xl]`` to a padded tensor, rescaled with ``x * 2 - 1``.

    Each image goes through ``TF.to_tensor``, is padded by 2 pixels on each side of its last
    two dimensions, and is then mapped with ``x * 2 - 1``. The list of results replaces
    ``batch[xl]``; nothing is returned here.
    NOTE(review): ``inplace`` comes from diffusion_ai; presumably it makes the decorated
    function return the mutated batch -- confirm.
    """
    processed = []
    for img in batch[xl]:
        tensor = TF.to_tensor(img)
        padded = F.pad(tensor, (2, 2, 2, 2))
        processed.append(padded * 2 - 1)
    batch[xl] = processed

# Apply transformations to the dataset (with_transform runs transformi on the fly at access time)
transformed_ds = dsd.with_transform(transformi)
# Build the train/validation loaders from the transformed dataset dict with batch size bs.
dls = DataLoaders.from_dd(transformed_ds, bs, num_workers=4)

In [None]:
# Pull one batch from the training loader; xb is reused later as the real-image
# baseline when computing FID/KID.
dt = dls.train
xb,yb = next(iter(dt))

In [None]:
# UNet2DModel wrapper that is called with a single tuple, as the samplers below do: model((x_t, t))
class UNet(UNet2DModel):
    def forward(self, x):
        """Unpack ``x`` into the parent ``forward`` call and return the ``.sample`` attribute of its output."""
        output = super().forward(*x)
        return output.sample

In [None]:
# Initialize the model for FashionMNIST
# NOTE(review): this freshly built network is replaced by the checkpoint loaded on the next
# line; the construction is kept so the cell behaves exactly as before.
model = UNet(in_channels=1, out_channels=1, block_out_channels=(32, 64, 128, 128), norm_num_groups=8)
# The checkpoint is used directly as the model, i.e. it is a pickled module rather than a
# state_dict, so full unpickling is required. PyTorch >= 2.6 defaults to weights_only=True,
# which rejects such files; request the full unpickler explicitly.
# SECURITY: full unpickling can execute arbitrary code -- only load checkpoints you trust.
model = torch.load('models/fashion_ddpm.pkl', weights_only=False)

In [None]:
# Load inference model for FID, KID
# The file holds a pickled module (it is indexed and edited below), so full unpickling is
# required. PyTorch >= 2.6 defaults to weights_only=True, which rejects such files.
# SECURITY: full unpickling can execute arbitrary code -- only load checkpoints you trust.
inference_model = torch.load('models/inference.pkl', weights_only=False)
# Remove the layers at indices 8 and 7, highest index first so that index 7 still refers to
# the original layer. NOTE(review): presumably this strips the classification head so the
# model outputs features for FID/KID -- confirm against how inference.pkl was built.
del inference_model[8]
del inference_model[7]

In [None]:
# Evaluator that provides the .fid() / .kid() calls used below, built from the truncated
# inference model and the data loaders.
image_eval = ImageEval(inference_model, dls, cbs=[DeviceCB()])

<font face='monospace'>

### <b>Implementing DDIM</b>

The main process which differentiates DDPM and DDIM is the sampling process, which removes noise.

In the context of the DDIM scheduler, `eta` is a parameter that controls the weight of the noise added in each diffusion step.

The value of `eta` can influence the amount of noise added at each step and therefore the overall quality of the generated samples. A higher `eta` will result in more noise being added, which could potentially lead to more diverse but less accurate samples. Conversely, a lower `eta` will result in less noise being added, which could lead to more accurate but less diverse samples.


In [None]:
def linear_sched(betamin=0.0001,betamax=0.02,n_steps=1000):
    """Build a linear beta noise schedule.

    Returns a SimpleNamespace with three 1-D tensors of length ``n_steps``:
    ``a`` (1 - beta), ``abar`` (cumulative product of ``a``) and ``sig`` (sqrt of beta),
    where beta is linearly spaced from ``betamin`` to ``betamax``.
    """
    betas = torch.linspace(betamin, betamax, n_steps)
    alphas = 1. - betas
    alphas_cumulative = torch.cumprod(alphas, dim=0)
    return SimpleNamespace(a=alphas, abar=alphas_cumulative, sig=torch.sqrt(betas))

# Schedule with betamax lowered to 0.01 (betamin and n_steps keep their defaults).
sc = linear_sched(betamax=0.01)
# Module-level shortcuts; the samplers below read `abar` as a global.
abar = sc.abar
sig = sc.sig

In [None]:
def ddim_step(x_t, t, noise, abar_t, abar_t1, bbar_t, bbar_t1, eta):
    """Perform a single DDIM update and return the sample for the next (earlier) timestep.

    Args:
        x_t: current noisy sample.
        t: current timestep; fresh noise is only added while ``t > 0``.
        noise: the model's noise prediction for ``x_t``.
        abar_t, abar_t1: cumulative alpha at the current and at the target timestep.
        bbar_t, bbar_t1: the matching ``1 - abar`` values, as passed by the caller.
        eta: scale applied to the standard deviation of the added noise (0 adds none).
    """
    variance = (bbar_t1 / bbar_t) * (1 - abar_t / abar_t1)
    std = variance.sqrt() * eta
    # Estimate of the clean sample implied by the noise prediction.
    pred_x0 = (x_t - bbar_t.sqrt() * noise) / abar_t.sqrt()
    noise_scale = (bbar_t1 - std**2).sqrt()
    x_prev = abar_t1.sqrt() * pred_x0 + noise_scale * noise
    if t > 0:
        fresh_noise = torch.randn(x_prev.shape).to(x_prev)
        x_prev += std * fresh_noise  # Add random noise
    return x_prev

In [None]:
@torch.no_grad()
def sample(f, model, size, num_steps, skip_steps=1, eta=1.):
    """Generate samples by repeatedly applying the step function ``f``.

    Timesteps run from the largest multiple of ``skip_steps`` below ``num_steps`` down to 0.
    At each one the model predicts the noise and ``f`` is called as
    ``f(x_t, t, noise, abar_t, abar_next, 1 - abar_t, 1 - abar_next, eta)``, with the
    cumulative alphas read from the module-level ``abar``.
    Returns the list of intermediate samples (float, on CPU), one per timestep.
    """
    schedule_steps = list(range(0, num_steps, skip_steps))[::-1]
    current = torch.randn(size).to(model.device)
    history = []
    for idx, step in enumerate(progress_bar(schedule_steps)):
        # The step after timestep 0 targets the clean image, where the cumulative alpha is 1.
        if step > 0:
            abar_next = abar[schedule_steps[idx + 1]]
        else:
            abar_next = torch.tensor(1)
        abar_now = abar[step]
        predicted_noise = model((current, step))
        current = f(current, step, predicted_noise, abar_now, abar_next, 1 - abar_now, 1 - abar_next, eta)
        history.append(current.float().cpu())
    return history

In [None]:
# Define the size of the samples and generate them
sample_size = (16, 1, 32, 32)
# num_steps=1000 with skip_steps=10 visits timesteps 990, 980, ..., 0 (100 model calls);
# eta keeps its default of 1.
samples = sample(ddim_step, model, sample_size, 1000, 10)

In [None]:
# Scale and show the images
# samples[-1] is the output of the final sampling step.
# NOTE(review): the factor of 2 presumably maps a checkpoint trained on [-0.5, 0.5] data onto
# the [-1, 1] range produced by transformi -- confirm against how the checkpoint was trained.
scaled_samples = (samples[-1] * 2)#.clamp(-1, 1)
# Only 16 samples were generated, so the [:25] slice shows all of them.
show_images(scaled_samples[:25], imsize=1.5)

<font face='monospace'>
Calculate FID, KID scores.

In [None]:
# FID and KID of the generated samples, as computed by image_eval.
image_eval.fid(scaled_samples),image_eval.kid(scaled_samples)

In [None]:
# Same metrics on one batch of real training images (xb), as a reference point.
image_eval.fid(xb),image_eval.kid(xb)

In [None]:
# clean_mem is not defined in this notebook; presumably it comes from the diffusion_ai star import.
clean_mem() # Free up some memory

<font face='monospace'>

That's it. Now if we don't want a pre-trained model, we can instead train another model using the below code and try the above steps again. Just see how fast DDIM works compared to DDPM.

---

<font face='monospace'>

### Complete architecture for reimplementing the pretrained model

In [None]:
# Define a linear schedule for DDPM
def linear_schedule(beta_min=0.0001, beta_max=0.02, num_steps=1000):
    """Return a SimpleNamespace describing a linear beta schedule.

    Fields (1-D tensors of length ``num_steps``): ``alpha`` = 1 - beta,
    ``alpha_bar`` = cumulative product of ``alpha``, ``sigma`` = sqrt(beta),
    with beta linearly spaced between ``beta_min`` and ``beta_max``.
    """
    betas = torch.linspace(beta_min, beta_max, num_steps)
    alphas = 1. - betas
    result = SimpleNamespace()
    result.alpha = alphas
    result.alpha_bar = alphas.cumprod(dim=0)
    result.sigma = betas.sqrt()
    return result

# Schedule with beta_max lowered to 0.01 (beta_min and num_steps keep their defaults).
schedule = linear_schedule(beta_max=0.01)
# Module-level shortcuts read as globals by collate_ddpm and sample_ddpm.
alpha_bar = schedule.alpha_bar
alpha = schedule.alpha
sigma = schedule.sigma

In [None]:
# Function to add noise to images
def noisify(images, alpha_bar):
    """Noise a batch of images at randomly drawn timesteps.

    Args:
        images: batch tensor; the per-sample coefficients are reshaped to (-1, 1, 1, 1),
            so a 4-D batch is expected.
        alpha_bar: 1-D tensor of cumulative alphas, indexed by timestep.

    Returns:
        ``((noisy_images, time_steps), noise)`` -- the model input tuple and the noise that
        was mixed in (the regression target). ``time_steps`` is moved to ``images.device``.
    """
    device = images.device
    batch_size = len(images)
    # Draw timesteps from the schedule's own length instead of a hard-coded 1000, so shorter
    # schedules cannot be indexed out of range and longer ones are fully covered.
    num_steps = len(alpha_bar)
    time_steps = torch.randint(0, num_steps, (batch_size,), dtype=torch.long)
    noise = torch.randn(images.shape, device=device)
    # Index on the CPU-created timesteps first, then move the coefficients to the images' device.
    alpha_bar_t = alpha_bar[time_steps].reshape(-1, 1, 1, 1).to(device)
    noisy_images = alpha_bar_t.sqrt() * images + (1 - alpha_bar_t).sqrt() * noise
    return (noisy_images, time_steps.to(device)), noise

In [None]:
# Custom UNet model for DDPM: takes one (noisy_images, timesteps) tuple, as produced by noisify
class UNet(UNet2DModel):
    def forward(self, x):
        """Forward the unpacked tuple ``x`` to the parent model and return its ``.sample`` attribute."""
        result = super().forward(*x)
        return result.sample

# Initialize DDPM model
def init_ddpm(model):
    """Apply the DDPM weight initialisation to ``model`` in place.

    - zero the ``conv2`` weight of every resnet in the down and up blocks,
    - orthogonally initialise the conv weight of every downsampler,
    - zero the weight of the output convolution.

    The downsampler pass runs once per down block. It used to sit inside the resnet loop,
    which re-ran it once per resnet, reused the loop variable ``p``, and skipped it for a
    block without resnets.
    """
    for block in model.down_blocks:
        for resnet in block.resnets:
            resnet.conv2.weight.data.zero_()
        # A block without downsampling exposes `downsamplers` as None; treat that as empty.
        downsamplers = block.downsamplers if block.downsamplers is not None else ()
        for downsampler in downsamplers:
            init.orthogonal_(downsampler.conv.weight)
    for block in model.up_blocks:
        for resnet in block.resnets:
            resnet.conv2.weight.data.zero_()
    model.conv_out.weight.data.zero_()

# Collate function for DDPM
def collate_ddpm(batch):
    """Collate ``batch`` with ``default_collate`` and noisify its image column.

    Returns whatever ``noisify`` returns for the images under key ``xl`` and the
    module-level ``alpha_bar`` schedule.
    """
    collated = default_collate(batch)
    images = collated[xl]
    return noisify(images, alpha_bar)

# Create dataloaders for DDPM
def create_dataloader(dataset):
    """Wrap ``dataset`` in a DataLoader that batches with ``bs`` and noisifies via ``collate_ddpm``."""
    loader_options = dict(batch_size=bs, collate_fn=collate_ddpm, num_workers=4)
    return DataLoader(dataset, **loader_options)

In [None]:
# Create the data loader
# This rebinds `dls` to loaders whose batches are ((noisy_images, timesteps), noise) tuples
# built by collate_ddpm; the 'train' and 'test' splits are used.
dls = DataLoaders(create_dataloader(transformed_ds['train']), create_dataloader(transformed_ds['test']))

In [None]:
# Initialize the model for FashionMNIST
# NOTE(review): block_out_channels ends in 256 here, whereas the UNet constructed before
# loading the checkpoint earlier in the notebook ends in 128 -- confirm which one is intended.
model = UNet(in_channels=1, out_channels=1, block_out_channels=(32, 64, 128, 256), norm_num_groups=8)
# Print the total number of parameters.
print(sum(p.numel() for p in model.parameters()))
# Apply the DDPM-specific weight initialisation in place.
init_ddpm(model)

In [None]:
# Training configuration
LR = 1e-3
EPOCHS = 1
# AdamW with eps raised to 1e-5.
opt_func = partial(optim.AdamW, eps=1e-5)
# OneCycleLR needs the total number of optimiser steps: epochs x batches per epoch.
total_steps = EPOCHS * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=LR, total_steps=total_steps)
callbacks = [DeviceCB(), MixedPrecision(), ProgressCB(plot=True), MetricsCB(), BatchSchedCB(sched)]

# Note: the MixedPrecision() callback uses GradScaler, which needs a GPU; it may crash on CPU-only machines.

# Create the learner (model + data + MSE loss between predicted and actual noise)
learn = Learner(model, dls, nn.MSELoss(), lr=LR, cbs=callbacks, opt_func=opt_func)

In [None]:
# Train for EPOCHS epochs; total_steps for the LR schedule was computed from the same EPOCHS value.
learn.fit(EPOCHS)

In [None]:
# DDPM sampler
@torch.no_grad()
def sample_ddpm(model, size):
    """Generate samples with the full DDPM reverse process.

    Starts from Gaussian noise of shape ``size`` (on the device/dtype of the model's first
    parameter) and visits every timestep of the module-level schedule (``alpha_bar``,
    ``alpha``, ``sigma``) from last to first.

    Returns:
        List of intermediate samples (float, on CPU), one per timestep; the last entry is
        the final result.
    """
    parameters = next(model.parameters())
    generated_images = torch.randn(size).to(parameters)
    predictions = []
    # Walk the whole schedule rather than a hard-coded 1000 steps, so the sampler stays in
    # sync with whatever length the schedule was built with.
    for t in reversed(range(len(alpha_bar))):
        time_batch = torch.full((generated_images.shape[0],), t, device=parameters.device, dtype=torch.long)
        # No fresh noise is injected on the final step (t == 0).
        noise = (torch.randn(generated_images.shape) if t > 0 else torch.zeros(generated_images.shape)).to(parameters)
        # The step after t == 0 targets the clean image, where the cumulative alpha is 1.
        alpha_bar_t1 = alpha_bar[t - 1] if t > 0 else torch.tensor(1)
        beta_bar_t = 1 - alpha_bar[t]
        beta_bar_t1 = 1 - alpha_bar_t1
        predicted_noise = model((generated_images, time_batch))
        # Estimate of the clean image implied by the predicted noise.
        x0_hat = ((generated_images - beta_bar_t.sqrt() * predicted_noise) / alpha_bar[t].sqrt())
        # Weighted mix of the clean-image estimate and the current sample, plus sigma[t]-scaled noise.
        generated_images = x0_hat * alpha_bar_t1.sqrt() * (1 - alpha[t]) / beta_bar_t + generated_images * alpha[t].sqrt() * beta_bar_t1 / beta_bar_t + sigma[t] * noise
        predictions.append(generated_images.float().cpu())
    return predictions

In [None]:
# Sample images using DDPM
samples = sample_ddpm(model, (3, 1, 32, 32))
# NOTE(review): `+ 0.5` maps [-0.5, 0.5] onto [0, 1], but transformi scales the training
# data with `* 2 - 1`; if the model was trained in this notebook, `(x + 1) / 2` may be the
# intended mapping -- confirm.
scaled_samples = (samples[-1] + 0.5).clamp(0, 1)
# Only 3 samples were generated, so the [:16] slice shows all of them.
show_images(scaled_samples[:16], imsize=1.5)

<font face='monospace'>
    
**DDPM ↑**

---

In [None]:
def linear_sched(betamin=0.0001,betamax=0.02,n_steps=1000):
    """Linear beta schedule from ``betamin`` to ``betamax`` over ``n_steps`` steps.

    Returns a SimpleNamespace with ``a`` = 1 - beta, ``abar`` = cumulative product of ``a``
    and ``sig`` = sqrt(beta).
    """
    betas = torch.linspace(betamin, betamax, n_steps)
    keep = 1. - betas
    fields = {"a": keep, "abar": keep.cumprod(dim=0), "sig": betas.sqrt()}
    return SimpleNamespace(**fields)

# Initialize scheduler
# NOTE(review): n_steps is not passed to linear_sched; the schedule length comes from that
# function's default (also 1000), so the two only agree by coincidence of values.
n_steps = 1000
sc = linear_sched(betamax=0.01)
# Module-level shortcuts; sample_ddim reads `abar` as a global.
abar = sc.abar
sig = sc.sig

In [None]:
# Define DDIM step function
def ddim_step(x_t, t, noise, alpha_bar_t, alpha_bar_t1, beta_bar_t, beta_bar_t1, eta, sigma=None):
    """Perform one DDIM update and return the sample for the next (earlier) timestep.

    The positional order matches how the samplers in this notebook call the step function:
    ``f(x_t, t, noise, abar_t, abar_t1, 1 - abar_t, 1 - abar_t1, eta)``. The previous
    signature had no ``t`` slot, so every argument after ``x_t`` landed one position off.

    Args:
        x_t: current noisy sample.
        t: current timestep; accepted for call compatibility and not used here.
        noise: the model's noise prediction for ``x_t``.
        alpha_bar_t, alpha_bar_t1: cumulative alpha at the current / target timestep (tensors).
        beta_bar_t, beta_bar_t1: the matching ``1 - alpha_bar`` values (tensors).
        eta: scale of the stochastic part; 0 makes the step deterministic.
        sigma: ignored; kept only so keyword callers of the old signature still work. The
            noise scale is always recomputed from the other arguments.
    """
    # Clamp the radicands so rounding error cannot feed a tiny negative into sqrt (-> NaN).
    # torch.max(sigma, 0) after the fact would not clear a NaN.
    noise_ratio = (beta_bar_t1 / beta_bar_t).clamp(min=0)
    signal_drop = (1 - alpha_bar_t / alpha_bar_t1).clamp(min=0)
    sigma = (noise_ratio.sqrt() * signal_drop.sqrt() * eta).clamp(min=0)
    # Estimate of the clean sample, clipped to the same range as before.
    x0_hat = ((x_t - (1 - alpha_bar_t).sqrt() * noise) / alpha_bar_t.sqrt()).clamp(-1.5, 1.5)
    # For eta > 1, sigma**2 can exceed beta_bar_t1; clamp so the direction term stays finite.
    direction_scale = (beta_bar_t1 - sigma**2).clamp(min=0).sqrt()
    x_prev = alpha_bar_t1.sqrt() * x0_hat + direction_scale * noise
    # On the final step beta_bar_t1 is 0, hence sigma is 0 and no noise is added.
    x_prev += sigma * torch.randn(x_prev.shape).to(x_prev)
    return x_prev

In [None]:
# Define sampling function
@torch.no_grad()
def sample_ddim(f, model, size, num_steps, skip_steps=1, eta=1.):
    """Generate samples using the DDIM scheduler.

    Visits the timesteps ``range(0, num_steps, skip_steps)`` in descending order. At each
    one the model predicts the noise and the step function is invoked as
    ``f(x_t, t, noise, abar_t, abar_next, 1 - abar_t, 1 - abar_next, eta)``, with the
    cumulative alphas taken from the module-level ``abar``.
    Returns the list of intermediate samples (float, on CPU), one per visited timestep.
    """
    steps_desc = list(range(0, num_steps, skip_steps))
    steps_desc.reverse()
    latent = torch.randn(size).to(model.device)
    trajectory = []
    for position, step in enumerate(progress_bar(steps_desc)):
        abar_here = abar[step]
        # After timestep 0 the target is the clean image, where the cumulative alpha is 1.
        if step > 0:
            abar_target = abar[steps_desc[position + 1]]
        else:
            abar_target = torch.tensor(1)
        eps_hat = model((latent, step))
        latent = f(latent, step, eps_hat, abar_here, abar_target, 1 - abar_here, 1 - abar_target, eta)
        trajectory.append(latent.float().cpu())
    return trajectory

In [None]:
# Sample images using DDIM
sample_size = (256, 1, 32, 32)
# Cover the whole n_steps-long schedule in 100 evenly spaced steps (timesteps 990, 980, ..., 0).
# Passing num_steps=100 with the default skip_steps=1 would only visit timesteps 99..0,
# i.e. start from pure noise at a point where the schedule expects an almost clean image.
ddim_predictions = sample_ddim(ddim_step, model, sample_size, n_steps, n_steps // 100, eta=1.)
# NOTE(review): the factor of 2 assumes the model outputs values in [-0.5, 0.5]; transformi
# scales the training data with `* 2 - 1`, so confirm this matches the model in use.
s = (ddim_predictions[-1] * 2)  # Scale outputs to have range between -1 and 1

<font face='monospace'>
    
**DDIM ↑**

---

In [None]:
# Show the first 16 of the generated DDIM samples.
show_images(s[:16], imsize=1.5)

In [None]:
# FID and KID of the DDIM samples, plus the sample tensor's shape.
image_eval.fid(s),image_eval.kid(s),s.shape

In [None]:
# Try for different number of steps.

# sample_ddim has no `steps` parameter (that call raised a TypeError); express "50 sampling
# steps" through its actual arguments: the full n_steps-long schedule with a stride of
# n_steps // 50.
preds = sample_ddim(ddim_step, model, sample_size, n_steps, n_steps // 50, eta=1.)
# FID of the final samples, scaled by 2 like the 100-step run above.
image_eval.fid(preds[-1]*2)

In [None]:
# clean_mem is not defined in this notebook; presumably it comes from the diffusion_ai star import.
clean_mem() # Free up some memory