# Diffusion models

In [None]:
!wget --quiet --show-progress -O utils.py "https://raw.githubusercontent.com/daniil-shlenskii/isp-2024-introduction-to-generative-diffusion-models/main/coding/utils.py"

In [None]:
import os
import re
import random

import numpy as np
import seaborn as sns
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
from sklearn.datasets import make_swiss_roll
from sklearn.utils import shuffle
from tqdm import tqdm
from copy import deepcopy
from typing import Optional, Tuple, List

from utils import get_labeled_data_loader, MyUNet


MODELS_DIR = "models"
os.makedirs(MODELS_DIR, exist_ok=True)
SAVE_DPPM_SWISS_PATH = f"{MODELS_DIR}/ddpm_swiss.pt"
SAVE_DPPM_CFG_SWISS_PATH = f"{MODELS_DIR}/ddpm_cfg_swiss.pt"
SAVE_DPPM_MNIST_PATH = f"{MODELS_DIR}/ddpm_mnist.pt"
SAVE_DPPM_CFG_MNIST_PATH = f"{MODELS_DIR}/ddpm_cfg_mnist.pt"

## DDPM

In this part you will implement your own diffusion model (DDPM) and apply it to the Swiss roll dataset.

In [None]:
def make_swiss_dataset(num_samples):
    X0, _ = make_swiss_roll(num_samples // 2, noise=0.3, random_state=0)
    X1, _ = make_swiss_roll(num_samples // 2, noise=0.3, random_state=0)
    X0 = X0[:, [0, 2]]
    X1 = X1[:, [0, 2]]
    X1 = -X1
    X, y = shuffle(
        np.concatenate([X0, X1], axis=0),
        np.concatenate([np.zeros(len(X0)), np.ones(len(X1))], axis=0),
        random_state=0)
    X = (X - X.mean(axis=0)) / X.std(axis=0)

    return X, y

X, y = make_swiss_dataset(2000)

In [None]:
plt.figure(figsize=(4, 4))
sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=y);

**Quick recap of diffusion models theory**

A diffusion model consists of a forward and a backward (reverse) process.

The forward process is defined as a posterior distribution $q(x_{1:T}|x_0)$.

It is a Markov chain that sequentially adds Gaussian noise to a given object $x_0$.

At every step the noise is added with a different magnitude, determined by a variance schedule $\{\beta_1, \ldots, \beta_T\}$.

If the schedule is chosen properly and $T$ goes to infinity (or is large enough), the chain converges to pure noise $\mathcal{N}(0, I)$.

The distributions $q$ have the following form:
$$
 q(x_t | x_{t - 1}) := \mathcal{N}(x_t; \sqrt{1 - \beta_t}x_{t - 1}, \beta_tI), \ \ \ \ \ \ \ q(x_{1:T}|x_0) = \prod_{t = 1}^T q(x_t | x_{t - 1})
$$


Now let's look at the backward process.

The backward process sequentially denoises pure Gaussian noise until an object from the original distribution is obtained.

So a diffusion model is a probabilistic model with latent variables
$p_\theta(x_0) := \int p_\theta(x_{0:T}) dx_{1:T}$,
where the latents $x_1, ..., x_T$ correspond to noised objects and $x_0$ is an object from the original distribution.

The joint distribution $p_\theta(x_{0:T})$ is called the backward diffusion process; it is a Markov chain of Gaussian distributions
$p_\theta(x_{t-1}|x_{t})$:

$$
p_\theta(x_{0:T}) = p_\theta(x_T) \prod_{t = 1}^Tp_{\theta}(x_{t-1}|x_t) \ \ \ \ \ \ \ \ \ p_\theta(x_{T})=\mathcal{N}(x_T | 0, I)
$$
$$
  p_{\theta}(x_{t - 1}|x_t):= \mathcal{N}(x_{t - 1}; \mu_{\theta}(x_t, t), \Sigma_{\theta}(x_t, t))
$$



Let's go back to the distribution $q(x_t | x_{t - 1})$.

To obtain $x_t$ from it directly, we would have to sample $x_1, ..., x_{t - 1}$ iteratively.

However, thanks to the properties of the Gaussian distribution, this can be done more efficiently.

Let's denote
$\alpha_t := 1- \beta_t$ and $\bar{\alpha}_t:= \prod_{i = 1}^t\alpha_i$.

Then
$$
q(x_t | x_0) = \mathcal{N}(x_t;\sqrt{\bar{\alpha}_t} x_0, (1-\bar{\alpha}_t)I) \quad \quad \quad \quad \quad \quad (1)
$$
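
As a sanity check of equation (1), the sketch below tracks the coefficient of $x_0$ and the noise variance step by step and compares them with the closed form. It is written in plain Python with scalars; the helper names (`linear_betas`, `iterate_forward_stats`, `closed_form_stats`) are made up for this illustration, and the schedule endpoints are assumed to mirror `get_beta_schedule`.

```python
import math

def linear_betas(num_steps):
    """Linear variance schedule (endpoints assumed to mirror get_beta_schedule)."""
    scale = 1000 / num_steps
    start, end = scale * 1e-4, scale * 2e-2
    return [start + (end - start) * i / (num_steps - 1) for i in range(num_steps)]

def iterate_forward_stats(betas):
    """Track the coefficient of x_0 and the noise variance one step at a time."""
    coef, var = 1.0, 0.0
    stats = []
    for beta in betas:
        coef = math.sqrt(1.0 - beta) * coef  # the mean is scaled by sqrt(1 - beta_t)
        var = (1.0 - beta) * var + beta      # old noise is scaled, fresh noise is added
        stats.append((coef, var))
    return stats

def closed_form_stats(betas):
    """Equation (1): coefficient sqrt(abar_t) and variance 1 - abar_t."""
    abar, stats = 1.0, []
    for beta in betas:
        abar *= 1.0 - beta
        stats.append((math.sqrt(abar), 1.0 - abar))
    return stats

betas = linear_betas(100)
step_by_step = iterate_forward_stats(betas)
closed_form = closed_form_stats(betas)
max_gap = max(
    max(abs(a - c), abs(b - d))
    for (a, b), (c, d) in zip(step_by_step, closed_form)
)
print(f"largest discrepancy: {max_gap:.2e}")
```

The two computations should agree up to floating-point rounding, which is why $x_t$ can be sampled in one shot instead of iterating over all previous steps.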

The model can then be trained by optimizing the individual terms of the variational lower bound:
$$
L_{VLB} = \mathbb{E}_q \Big[\underbrace{D_\text{KL}(q(\mathbf{x}_T |
\mathbf{x}_0) \parallel p_\theta(\mathbf{x}_T))}_{L_T} + \sum_{t=2}^T
\underbrace{D_\text{KL}(q(\mathbf{x}_{t-1} | \mathbf{x}_t,
\mathbf{x}_0) \parallel p_\theta(\mathbf{x}_{t-1}
| \mathbf{x}_t))}_{L_{t-1}} \underbrace{- \log p_\theta(\mathbf{x}_0
| \mathbf{x}_1)}_{L_0}\Big]
$$

For training we only need an explicit form of the posterior $q(\mathbf{x}_{t-1} | \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; \boldsymbol{\mu}(\mathbf{x}_t, \mathbf{x}_0), \tilde{\beta}_t \mathbf{I})$, where:

$$
    \boldsymbol{\mu}(\mathbf{x}_t, \mathbf{x}_0) = \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} \mathbf{x}_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\beta_t}{1 - \bar{\alpha}_t} \mathbf{x}_0 \ \ \ \ \ \ (2)
$$
$$
    \tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \cdot \beta_t  \quad \quad \quad \quad \quad \quad \quad (3)
$$
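
To make equations (2) and (3) concrete, here is a small scalar sketch that evaluates the posterior coefficients for every step of a schedule. The function name `posterior_coefficients` is hypothetical, the linear schedule is only an example, and $\bar{\alpha}_0 := 1$ is assumed as the convention for the first step.

```python
import math

def posterior_coefficients(betas):
    """Return per-step (xt_coef, x0_coef, variance) from equations (2) and (3)."""
    coefs, abar_prev = [], 1.0  # abar_0 := 1 by convention
    for beta in betas:
        alpha = 1.0 - beta
        abar = abar_prev * alpha
        xt_coef = math.sqrt(alpha) * (1.0 - abar_prev) / (1.0 - abar)
        x0_coef = math.sqrt(abar_prev) * beta / (1.0 - abar)
        variance = (1.0 - abar_prev) / (1.0 - abar) * beta
        coefs.append((xt_coef, x0_coef, variance))
        abar_prev = abar
    return coefs

# Example schedule: 100 betas spaced linearly between 0.001 and 0.2.
betas = [0.001 + (0.2 - 0.001) * i / 99 for i in range(100)]
coefs = posterior_coefficients(betas)
first_xt, first_x0, first_var = coefs[0]
print(first_xt, first_x0, first_var)
```

At the first step the posterior collapses onto $x_0$: the $x_t$ coefficient and the variance vanish and the $x_0$ coefficient is one (up to rounding), which is consistent with adding no noise on the final sampling step.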


See [Denoising Diffusion Probabilistic Models (Ho et al. 2020)](https://arxiv.org/abs/2006.11239) for details.

However, that paper showed that training with a simpler loss gives better results.

Recall that
$$
x_t(x_0, \epsilon) = \sqrt{\bar{\alpha}_t} x_0 +  \sqrt{(1-\bar{\alpha}_t)}\epsilon, \ \ \ \epsilon \sim \mathcal{N}(0, I) \quad \quad \quad \quad \quad \quad \quad (4)
$$

Let our model predict $\epsilon$ from the equality above and train it by optimizing the following loss:

$$L^{simple}_t = \mathbb{E}_{x_0, \epsilon, t}\bigg[ \|\epsilon - \epsilon_{\theta}(x_t, t)\|^2\bigg]$$

This loss will be used in this task.
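
The loss can be illustrated with a toy Monte Carlo estimate on scalar data. This is only a sketch under strong assumptions: the "dataset" is a single point, the predictors are hand-written functions rather than networks, and they receive $\bar{\alpha}_t$ as an extra argument purely for convenience. All names here (`simple_loss`, `oracle`, `zero_predictor`) are hypothetical.

```python
import math
import random

def simple_loss(eps_model, data, betas, num_draws, rng):
    """Monte Carlo estimate of E[(eps - eps_model(x_t, t))^2] for scalar data."""
    abars, abar = [], 1.0
    for beta in betas:
        abar *= 1.0 - beta
        abars.append(abar)
    total = 0.0
    for _ in range(num_draws):
        x0 = rng.choice(data)
        t = rng.randrange(len(betas))  # uniform timestep index
        eps = rng.gauss(0.0, 1.0)
        x_t = math.sqrt(abars[t]) * x0 + math.sqrt(1.0 - abars[t]) * eps  # equation (4)
        total += (eps - eps_model(x_t, t, abars[t])) ** 2
    return total / num_draws

betas = [0.001 + (0.2 - 0.001) * i / 99 for i in range(100)]
data = [3.0]  # toy dataset: every object equals 3.0

def oracle(x_t, t, abar):
    # With a single data point the true noise can be recovered exactly.
    return (x_t - math.sqrt(abar) * 3.0) / math.sqrt(1.0 - abar)

def zero_predictor(x_t, t, abar):
    # A predictor that ignores its input; its loss is about E[eps^2] = 1.
    return 0.0

rng = random.Random(0)
oracle_loss = simple_loss(oracle, data, betas, 2000, rng)
zero_loss = simple_loss(zero_predictor, data, betas, 2000, rng)
print(oracle_loss, zero_loss)
```

The oracle reaches a loss of essentially zero, while the constant predictor stays near one; a trained network on real data will land somewhere in between.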

To sample (run the backward process), we need to obtain
$\mu_{\theta}(x_t, t)$ from $\epsilon_{\theta}(x_t, t)$.

To do that, express $\hat{x}_0(\epsilon_{\theta}, x_t)$ from eq. (4) and
substitute it into eq. (2).
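
The substitution can be checked numerically on scalars. The sketch below inverts equation (4), plugs the result into equation (2), and compares it with the closed form $\frac{1}{\sqrt{\alpha_t}}\big(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon\big)$ given in Ho et al. (2020). The function names and the numeric values are arbitrary choices made for this illustration.

```python
import math

def x0_from_eps(x_t, eps, abar):
    """Invert equation (4) for x_0."""
    return (x_t - math.sqrt(1.0 - abar) * eps) / math.sqrt(abar)

def posterior_mean(x_t, x0, beta, abar, abar_prev):
    """Equation (2)."""
    alpha = 1.0 - beta
    xt_coef = math.sqrt(alpha) * (1.0 - abar_prev) / (1.0 - abar)
    x0_coef = math.sqrt(abar_prev) * beta / (1.0 - abar)
    return xt_coef * x_t + x0_coef * x0

def mean_from_eps(x_t, eps, beta, abar):
    """Closed form of the mean after the substitution."""
    return (x_t - beta / math.sqrt(1.0 - abar) * eps) / math.sqrt(1.0 - beta)

beta, abar_prev = 0.05, 0.6  # arbitrary illustrative values
abar = abar_prev * (1.0 - beta)
x0_true, eps = 1.5, -0.7
x_t = math.sqrt(abar) * x0_true + math.sqrt(1.0 - abar) * eps  # equation (4)

x0_hat = x0_from_eps(x_t, eps, abar)
mu_two_step = posterior_mean(x_t, x0_hat, beta, abar, abar_prev)
mu_direct = mean_from_eps(x_t, eps, beta, abar)
print(x0_hat, mu_two_step, mu_direct)
```

Both routes give the same mean, so either form can be used when implementing the reverse step; the two-step route additionally exposes $\hat{x}_0$, which is convenient when it has to be clipped.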

_____

**Now to the task**

In [None]:
# Utility function: returns arr[timesteps] (timesteps are indices), broadcast to broadcast_shape. See its use in the diffusion classes below.
def _extract_into_tensor(arr: th.Tensor, timesteps: th.Tensor, broadcast_shape: Tuple):
    """
    Extract values from a 1-D torch tensor for a batch of indices.
    :param arr: 1-D torch tensor.
    :param timesteps: a tensor of indices to extract from arr.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
    """
    res = arr.to(device=timesteps.device)[timesteps].float()
    while len(res.shape) < len(broadcast_shape):
        res = res[..., None]
    return res.expand(broadcast_shape)

### DDPM class

It consists of
- ForwardDiffusion class
- ReverseDiffusion class
- Model predicting noise

You are to fill in the gaps marked with `your code`

In [None]:
def get_beta_schedule(num_diffusion_timesteps: int) -> th.Tensor:
    scale = 1000 / num_diffusion_timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
    betas = th.from_numpy(betas).double()
    return betas

In [None]:
class BaseDiffusion:
    def __init__(self, betas: th.Tensor) -> None:
        self.betas = betas
        self.alphas = 1 - self.betas
        self.alphas_cumprod = th.cumprod(self.alphas, dim=-1)
        self.num_timesteps = len(self.betas)

basediff = BaseDiffusion(get_beta_schedule(20))
basediff.alphas_cumprod

In [None]:
class ForwardDiffusion(BaseDiffusion):
    def q_mean_variance(self, x0: th.Tensor, t: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        # ====
        # your code
        # calculate mean and variance of the distribution q(x_t | x_0) (use equation (1))
        ...
        # ====
        return mean, variance

    def q_sample(self, x0: th.Tensor, t: th.Tensor, noise: Optional[th.Tensor]=None) -> th.Tensor:
        # ====
        # your code
        # sample from the distribution q(x_t | x_0) (use equation (1))
        ...
        # ====
        return samples

Let's take a look at how our data gets noisier as $t$ increases.

In [None]:
T = 100
forward_diffusion = ForwardDiffusion(get_beta_schedule(T))

plot_n_steps = 8
noise_n_steps = 3

_, axs = plt.subplots(1, plot_n_steps, figsize=(plot_n_steps * 4, 4))
t_to_plot = list(np.round(np.linspace(forward_diffusion.num_timesteps - 1, 15, num=noise_n_steps)).astype("int")) + \
            list(np.round(np.linspace(10, 0, num=plot_n_steps - noise_n_steps)).astype("int"))
t_to_plot = t_to_plot[::-1]
for i,t in enumerate(t_to_plot):
    x = forward_diffusion.q_sample(
        x0=th.from_numpy(X),
        t=th.ones_like(th.from_numpy(y)).long() * t,
    )
    sns.scatterplot(x=x[:,0], y=x[:,1], hue=y, ax=axs[i])
    axs[i].set(title=t)

**How can we tell whether $T$ is large enough?**
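
One possible check, offered as a sketch rather than the intended answer: by equation (1), $q(x_T | x_0)$ has mean $\sqrt{\bar{\alpha}_T} x_0$ and variance $1 - \bar{\alpha}_T$, so it is close to $\mathcal{N}(0, I)$ when $\bar{\alpha}_T$ is close to zero. The helper `final_signal_level` is hypothetical and assumes the same linear schedule as `get_beta_schedule`.

```python
import math

def final_signal_level(num_steps):
    """Return abar_T for a linear schedule (endpoints assumed as in get_beta_schedule)."""
    scale = 1000 / num_steps
    start, end = scale * 1e-4, scale * 2e-2
    abar = 1.0
    for i in range(num_steps):
        beta = start + (end - start) * i / (num_steps - 1)
        abar *= 1.0 - beta
    return abar

for num_steps in (100, 1000):
    abar_T = final_signal_level(num_steps)
    # sqrt(abar_T) is the fraction of x_0 that survives in the mean of q(x_T | x_0)
    print(num_steps, f"abar_T={abar_T:.2e}", f"signal={math.sqrt(abar_T):.2e}")
```

If the surviving signal is small compared with the scale of the (standardized) data, starting the reverse process from pure Gaussian noise is a reasonable approximation.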

In [None]:
class ReverseDiffusion(BaseDiffusion):
    def __init__(self, *args, clip_x0: Optional[bool]=False, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.alphas_cumprod_prev = th.cat(
            [th.tensor([1.0], device=self.betas.device), self.alphas_cumprod[:-1]], dim=0
        )

        # ====
        # your code
        # calculate variance of the distribution q(x_{t-1} | x_t, x_0) (use equation (3))
        self.variance = ...
        # ====

        # ====
        # your code
        # calculate coefficients of the distribution q(x_{t-1} | x_t, x_0) mean (use equation (2))
        self.xt_coef = ...
        self.x0_coef = ...
        # ====

        self.clip_x0 = clip_x0

    def get_x0(self, xt: th.Tensor, eps: th.Tensor, t: th.Tensor) -> th.Tensor:
        # ====
        # your code
        # get x_0 by inverting equation (4)
        ...
        # ====
        if self.clip_x0:
            x0 = x0.clamp(-1., 1.)
        return x0

    def q_posterior_mean_variance(
        self, xt: th.Tensor, eps: th.Tensor, t: th.Tensor
    ) -> Tuple[th.Tensor, th.Tensor]:
        # ====
        # your code
        # get mean and variance of the distribution q(x_{t-1} | x_t, x_0) (use equations (2) and (3))
        ...
        # ====
        return mean, variance

    def p_sample(self, xt: th.Tensor, eps: th.Tensor, t: th.Tensor) -> th.Tensor:
        # read this code carefully
        mean, variance = self.q_posterior_mean_variance(xt=xt, eps=eps, t=t)
        noise = th.randn_like(xt, device=xt.device)

        nonzero_mask = th.ones_like(t)  # to not add any noise while predicting x0
        nonzero_mask[t == 0] = 0
        nonzero_mask = _extract_into_tensor(
            nonzero_mask, th.arange(nonzero_mask.shape[0]), xt.shape
        )
        nonzero_mask = nonzero_mask.to(xt.device)
        sample = mean + nonzero_mask * variance.sqrt() * noise
        return sample.float()

### Model for noise prediction

Now we have to implement the model with weights $\theta$ that parametrizes the backward process.

The model does not need to be complex; a few linear layers are enough.

Don't forget to take the classes $y$ and timesteps $t$ into account.

The model is supposed to predict the noise $\epsilon$: $\epsilon_\theta(x_t, t, y)$.

In [None]:
class ConditionalMLP(nn.Module):
    def __init__(self, d_in: int, T: int, n_classes: int, hidden_dim: Optional[int]=128):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.x_proj = nn.Linear(d_in, self.hidden_dim)
        self.t_proj = nn.Embedding(T, self.hidden_dim)
        self.y_embed = nn.Embedding(n_classes, self.hidden_dim)
        self.backbone = nn.Sequential(
            nn.Linear(self.hidden_dim, 2 * self.hidden_dim),
            nn.GELU(),
            nn.Linear(2 * self.hidden_dim, d_in)
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, x, t, y):
        '''
        :x input, e.g. images
        :t 1d th.LongTensor of timesteps
        :y 1d th.LongTensor of class labels
        '''
        x = self.x_proj(x)
        t = self.t_proj(t.int())
        y = self.y_embed(y)
        x = x + t + y
        x = F.gelu(x)
        return self.backbone(x)

Let's combine all the implemented pieces in the DDPM class.
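
Before the full class, the sampling loop can be sketched in scalar form. This is a toy illustration, not the class implementation: the data distribution is a single point, and a hand-written `oracle_eps` (hypothetical) stands in for a trained network. Equations (2), (3) and (4) are applied exactly as in the recap.

```python
import math
import random

def make_schedule(num_steps):
    """Linear betas (endpoints assumed as in get_beta_schedule) and their abar_t."""
    scale = 1000 / num_steps
    start, end = scale * 1e-4, scale * 2e-2
    betas = [start + (end - start) * i / (num_steps - 1) for i in range(num_steps)]
    abars, abar = [], 1.0
    for beta in betas:
        abar *= 1.0 - beta
        abars.append(abar)
    return betas, abars

def ancestral_sample(eps_model, betas, abars, rng):
    """Run the reverse chain from x_T ~ N(0, 1) down to x_0 for one scalar."""
    x = rng.gauss(0.0, 1.0)
    for t in reversed(range(len(betas))):
        abar_prev = abars[t - 1] if t > 0 else 1.0
        eps = eps_model(x, t)
        # x_0 estimate obtained by inverting equation (4)
        x0_hat = (x - math.sqrt(1.0 - abars[t]) * eps) / math.sqrt(abars[t])
        # posterior mean, equation (2)
        mean = (
            math.sqrt(1.0 - betas[t]) * (1.0 - abar_prev) / (1.0 - abars[t]) * x
            + math.sqrt(abar_prev) * betas[t] / (1.0 - abars[t]) * x0_hat
        )
        # posterior variance, equation (3)
        variance = (1.0 - abar_prev) / (1.0 - abars[t]) * betas[t]
        noise = rng.gauss(0.0, 1.0) if t > 0 else 0.0  # no noise on the last step
        x = mean + math.sqrt(variance) * noise
    return x

betas, abars = make_schedule(100)
target = 2.0  # toy "dataset" with a single point

def oracle_eps(x_t, t):
    # Exact noise for a point-mass data distribution; stands in for a trained network.
    return (x_t - math.sqrt(abars[t]) * target) / math.sqrt(1.0 - abars[t])

rng = random.Random(0)
samples = [ancestral_sample(oracle_eps, betas, abars, rng) for _ in range(5)]
print(samples)
```

With a perfect noise predictor every chain ends at the data point, whatever noise it started from. The `sample` method below follows the same loop on batches, with the network in place of the oracle.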

In [None]:
class DDPM(nn.Module):
    def __init__(
        self,
        betas: th.Tensor,
        model: nn.Module,
        clip_x0: Optional[bool] = False,
        shape: Optional[th.Tensor] = None,
        update_ema_after: Optional[int] = 2000,
    ) -> None:
        super().__init__()

        self.forward_diffusion = ForwardDiffusion(betas=betas)
        self.reverse_diffusion = ReverseDiffusion(betas=betas, clip_x0=clip_x0)
        self.model = model
        self.num_timesteps = len(betas)

        self.ema = deepcopy(model)
        self.update_ema_after = update_ema_after
        self.ema_counter = 0

        self.register_buffer("betas", betas)
        self.register_buffer("clip_x0", th.tensor(clip_x0, dtype=bool))
        self.register_buffer("shape", shape)

    @property
    def device(self) -> th.device:
        return next(self.parameters()).device

    @th.no_grad()
    def sample(self, y: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        assert self.shape is not None
        if self.ema_counter < self.update_ema_after:
            model = self.model
        else:
            model = self.ema

        num_samples = y.shape[0]
        x = th.randn((num_samples, *self.shape), device=self.device, dtype=th.float32)
        indices = list(range(self.num_timesteps))[::-1]

        for i in tqdm(indices):
            t = th.tensor([i] * num_samples, device=x.device)
            # ====
            # your code
            # 1) get epsilon from the model
            # 2) sample from the reverse diffusion
            eps = ...
            x = ...
            # ====
        return x, y

    def train_loss(self, x0: th.Tensor, y: th.Tensor) -> th.Tensor:
        self._update_ema()
        if self.shape is None:
            self.shape = th.tensor(list(x0.shape)[1:], device="cpu")
        t = th.randint(0, self.num_timesteps, size=(x0.size(0),), device=x0.device)
        noise = th.randn_like(x0)

        # ====
        # your code
        # 1) get x_t
        # 2) get epsilon from the model
        x_t = ...
        eps = ...
        # ====
        loss = F.mse_loss(eps, noise)
        return loss

    def _update_ema(self):
        self.ema_counter += 1
        if self.ema_counter < self.update_ema_after:
            return
        if self.ema_counter == self.update_ema_after:
            self.ema.load_state_dict(self.model.state_dict())

        ema_weight = 0.99
        new_ema_state_dict = self.ema.state_dict()
        model_state_dict = self.model.state_dict()
        for key, val in new_ema_state_dict.items():
            if isinstance(val, th.Tensor):
                new_ema_state_dict[key] = (
                    ema_weight * new_ema_state_dict[key] +
                    (1 - ema_weight) * model_state_dict[key]
                )
        self.ema.load_state_dict(new_ema_state_dict)

    @classmethod
    def from_pretrained(cls: "DDPM", model: nn.Module, ckpt_path: str) -> "DDPM":
        ckpt = th.load(ckpt_path)
        model_state_dict = {
            re.sub(r"ema\.", "", re.sub(r"model\.", "", key)):
            val for key, val in ckpt.items() if "ema." in key
        }
        model.load_state_dict(model_state_dict)
        return cls(
            betas=ckpt["betas"],
            model=model,
            clip_x0=ckpt["clip_x0"],
            shape=ckpt["shape"],
        )

### Training

In [None]:
def train_model(
    ddpm: DDPM,
    dataloader: DataLoader,
    lr: float,
    weight_decay: float,
    n_iters: int,
    device: str = "cpu",
    log_every: int = 500
):
    ddpm = ddpm.to(device)

    optimizer = th.optim.AdamW(
        ddpm.model.parameters(), lr=lr, weight_decay=weight_decay
    )
    step = 0
    curr_loss_gauss = 0.0
    curr_count = 0
    optimizer.zero_grad()
    data_iter = iter(dataloader)
    while step < n_iters:
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            batch = next(data_iter)
        x, y = batch["x"].to(device), batch["y"].to(device)

        loss = ddpm.train_loss(x, y)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        curr_count += len(x)
        curr_loss_gauss += loss.item() * len(x)

        if (step + 1) % log_every == 0:
            gloss = np.around(curr_loss_gauss / curr_count, 4)
            print(f"Step {(step + 1)}/{n_iters} Loss: {gloss}")
            curr_count = 0
            curr_loss_gauss = 0.0

        step += 1

In [None]:
T = 100
# ====
# your code
# choose these parameters
BATCH_SIZE = 1024
LR = 0.01
WEIGHT_DECAY = 0.0
N_ITERS = 15000
# ====

model = ConditionalMLP(d_in=2, T=T, n_classes=2)
device = "cpu" # cpu is enough

if not os.path.exists(SAVE_DPPM_SWISS_PATH):
    th.manual_seed(0)
    random.seed(0)

    ddpm = DDPM(betas=get_beta_schedule(T), model=model)
    dataloader = get_labeled_data_loader(X, y, batch_size=BATCH_SIZE, shuffle=True)

    train_model(
        ddpm=ddpm,
        dataloader=dataloader,
        lr=LR,
        weight_decay=WEIGHT_DECAY,
        n_iters=N_ITERS,
        device=device
    )
    th.save(ddpm.to("cpu").state_dict(), SAVE_DPPM_SWISS_PATH)
else:
    ddpm = DDPM.from_pretrained(model, SAVE_DPPM_SWISS_PATH)

_ = ddpm.to(device)

Now let's take a look at the data our model has learned to generate.

In [None]:
num_samples = X.shape[0]
ys = th.randint(0, 2, size=(num_samples,), device=device)
Xs, ys = ddpm.sample(ys)
plt.figure(figsize=(4, 4))
sns.scatterplot(x=Xs[:, 0], y=Xs[:, 1], hue=ys)

Let's look at the denoising process

In [None]:
x = th.randn(y.shape[0], 2, requires_grad=False)
y = th.tensor(y, requires_grad=False, dtype=th.long)

plot_n_steps = 8
noise_n_steps = 3
_, axs = plt.subplots(1, plot_n_steps, figsize=(plot_n_steps * 4, 4))

idx = 0
t_to_plot = list(np.round(np.linspace(ddpm.num_timesteps - 1, 15, num=noise_n_steps)).astype("int")) + \
            list(np.round(np.linspace(10, 0, num=plot_n_steps - noise_n_steps)).astype("int"))
for i in tqdm(range(ddpm.num_timesteps - 1, -1, -1)):
    t = th.tensor(i, dtype=th.long, requires_grad=False).expand(x.shape[0])
    with th.no_grad():
        eps = ddpm.model(x, t, y)
        x =  ddpm.reverse_diffusion.p_sample(x, eps, t)
    if i in t_to_plot:
        sns.scatterplot(x=x[:,0], y=x[:,1], hue=y, ax=axs[idx])
        axs[idx].set(title=i)
        idx += 1

### MNIST

Now we will apply the diffusion model to the MNIST dataset.

In [None]:
from torchvision.datasets.mnist import MNIST

def mnist_to_train_range(X):
    return ((X.astype("float32") / 255.) - 0.5) * 2

def mnist_from_train_range(X):
    return (((X.astype("float32") + 1.0) / 2) * 255.).astype("int")

dataset = MNIST("./datasets", download=True, train=True)
X = dataset.data.numpy().astype("float32")[:, None]
y = dataset.targets.numpy()
mnist_loader = get_labeled_data_loader(mnist_to_train_range(X), y, batch_size=64)

Let's plot several instances of the given dataset.

In [None]:
def show_images(images, ys, title=""):
    """Shows the provided images as sub-pictures in a square"""

    # Converting images to CPU numpy arrays
    if type(images) is th.Tensor:
        images = images.detach().cpu().numpy()
        ys = ys.detach().cpu().numpy()

    # Defining number of rows and columns
    rows = int(len(images) ** (1 / 2))
    cols = round(len(images) / rows)
    fig = plt.figure(figsize=(cols*2, rows*2))

    # Populating figure with sub-plots
    idx = 0
    for r in range(rows):
        for c in range(cols):
            fig.add_subplot(rows, cols, idx + 1)

            if idx < len(images):
                plt.imshow(images[idx][0], cmap="gray")
                plt.title(f"{int(ys[idx])}")
                plt.tick_params(bottom = False, labelbottom=False)
                idx += 1
    fig.suptitle(title, fontsize=30)

    # Showing the figure
    plt.show()

def show_first_batch(loader, batch_size=16):
    for batch in loader:
        show_images(batch["x"][:batch_size], batch["y"][:batch_size], "Images in the first batch")
        break

In [None]:
show_first_batch(mnist_loader)

Let's take a look at the forward process for the MNIST images.

In [None]:
forward_diffusion = ForwardDiffusion(betas=get_beta_schedule(1000))
plot_n_steps = 9

images = next(iter(mnist_loader))["x"][:1, 0]

_, axs = plt.subplots(1, plot_n_steps, figsize=(6 * plot_n_steps, 6))
t_to_plot = list(np.round(np.linspace(0, forward_diffusion.num_timesteps - 1, num=plot_n_steps)).astype("int"))
for i,t in enumerate(t_to_plot):
    x = forward_diffusion.q_sample(
        x0=images,
        t=(th.ones(images.shape[0], device=images.device) * t).long(),
    )
    axs[i].imshow(x[0], cmap="gray")
    axs[i].set(title=t)

The noise-predicting model is already written for you.

You can find details in the attached `utils.py` file.

In [None]:
# Downloading pretrained model
!gdown -O "models/ddpm_mnist.pt" "https://drive.google.com/uc?id=1fSPB08M6aBNmhjRgSn3qpdq5hXl1Xhao"

In [None]:
T = 1000
# ====
# your code
# choose these parameters
BATCH_SIZE = 1024
LR = 0.01
WEIGHT_DECAY = 0.0
N_ITERS = 5000
# ====

model = MyUNet()
device = th.device("cuda" if th.cuda.is_available() else "cpu")

if not os.path.exists(SAVE_DPPM_MNIST_PATH):
    th.manual_seed(0)
    random.seed(0)

    ddpm = DDPM(betas=get_beta_schedule(T), model=model, clip_x0=True)
    dataloader = get_labeled_data_loader(mnist_to_train_range(X), y, batch_size=BATCH_SIZE, shuffle=True)

    train_model(
        ddpm=ddpm,
        dataloader=dataloader,
        lr=LR,
        weight_decay=WEIGHT_DECAY,
        n_iters=N_ITERS,
        device=device
    )
    th.save(ddpm.to("cpu").state_dict(), SAVE_DPPM_MNIST_PATH)
else:
    ddpm = DDPM.from_pretrained(model, SAVE_DPPM_MNIST_PATH)

_ = ddpm.to(device)

Let's draw some samples with the model.

In [None]:
num_samples = 16
ys = th.randint(10, size=(num_samples,), device=device)
Xs, ys = ddpm.sample(y=ys)
show_images(Xs, ys)

## Faster sampling with DDPM

In the previous task it took us about 7 seconds to generate a batch of images with our diffusion model (even on a GPU).

That's not a big deal for now. But the larger the images in our dataset, the more time sampling requires (by a large factor).

This drawback of diffusion models can't generally be resolved by using more GPUs, since sampling is inherently sequential (in the previous task we ran our model 1000 times in a row).

There are several techniques to alleviate this drawback.

We are going to implement one of them, which was proposed in [Improved Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2102.09672).

So, assume we have already trained a model to "reverse" a Markov chain of length $T$.

Let's imagine that it corresponds to a shorter Markov chain over the timesteps
$\{S_0 = 0, S_1, \ldots, S_{T'-1}, S_{T'} = T\}$, where $T' < T$.

Then, to generate samples, we need only $T' (< T)$ model evaluations instead of $T$.

The only thing we have to make sure of is that $q^{new}(x_i) = q(x_{S_i})$.

This condition is satisfied if $q^{new}(x_i | x_0) = q(x_{S_i} | x_0)$.

By equation (1), this means $\bar{\alpha}^{new}_i = \bar{\alpha}_{S_i}$, which gives the betas of the new (shorter) diffusion process:
$$
    \begin{align}
        \beta_i^{new} = 1 - \frac{
            \bar{\alpha}_{S_i}
        }{
            \bar{\alpha}_{S_{i-1}}
        }
    \end{align}
$$

Since a diffusion process is fully described by its betas, we don't have to change anything in the sampling procedure.

We do, however, have to slightly adjust the inputs of the pretrained model $\epsilon$:
instead of $\epsilon(x_i, i)$ (where $i$ is a timestep of the new diffusion process) we have to use $\epsilon(x_i, S_i)$.
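
The re-spacing can be sketched on plain lists. The function `respace_betas` is a hypothetical helper, not the class method asked for below, and it works with 0-based indices into the original schedule; the schedule and the choice of retained steps are arbitrary examples.

```python
def respace_betas(betas, use_timesteps):
    """Return (new_betas, timestep_map) for the retained original indices."""
    keep = set(use_timesteps)
    new_betas, timestep_map = [], []
    abar, last_abar = 1.0, 1.0
    for i, beta in enumerate(betas):
        abar *= 1.0 - beta
        if i in keep:
            new_betas.append(1.0 - abar / last_abar)  # formula for beta_i^new
            timestep_map.append(i)                    # new index -> original index S_i
            last_abar = abar
    return new_betas, timestep_map

betas = [0.0001 + (0.02 - 0.0001) * i / 999 for i in range(1000)]
use_timesteps = range(9, 1000, 10)  # keep every 10th step, ending at the last one
new_betas, timestep_map = respace_betas(betas, use_timesteps)

# Cumulative products of the new chain should match abar at the retained steps.
orig_abars, abar = [], 1.0
for beta in betas:
    abar *= 1.0 - beta
    orig_abars.append(abar)
new_abar, gaps = 1.0, []
for new_beta, s in zip(new_betas, timestep_map):
    new_abar *= 1.0 - new_beta
    gaps.append(abs(new_abar - orig_abars[s]))
print(len(new_betas), max(gaps))
```

The gaps are zero up to rounding, so the shorter chain has the same marginals $q(x_{S_i} | x_0)$ as the original one at the retained steps. The pretrained network still expects original timesteps, which is what `timestep_map` is for.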

You are to fill in the gaps marked with `your code`.

In [None]:
class SpacedDDPM(DDPM):
    """
    A diffusion process which can skip steps in a base diffusion process.

    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps: Optional[Tuple]=None, **kwargs):
        if use_timesteps is None:
            use_timesteps = list(range(len(kwargs["betas"])))
        self.use_timesteps = set(use_timesteps)
        timestep_map = []
        original_num_steps = len(kwargs["betas"])

        base_kwargs = deepcopy(kwargs)
        base_ddpm = DDPM(**base_kwargs)
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(base_ddpm.forward_diffusion.alphas_cumprod):
            if i in self.use_timesteps:
                # ====
                # your code
                # 1) update new_betas (use the formula for beta_i^new above)
                # 2) update timestep_map (it is a mapping S: i -> S_i)
                ...
                # ====
        kwargs["betas"] = th.tensor(new_betas)
        kwargs["model"] = _WrappedModel(kwargs["model"], timestep_map)
        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(cls, *args, use_timesteps: Optional[Tuple]=None, **kwargs):
        ddpm = DDPM.from_pretrained(*args, **kwargs)
        return cls(
            use_timesteps = use_timesteps,
            betas = ddpm.betas,
            model = ddpm.model,
            clip_x0 = ddpm.clip_x0,
            shape = ddpm.shape,
            update_ema_after = ddpm.update_ema_after,
        )

class _WrappedModel(nn.Module):
    def __init__(self, model: nn.Module, timestep_map: List):
        super().__init__()
        self.model = model
        self.timestep_map = timestep_map

    @property
    def device(self):
        return next(self.model.parameters()).device

    def __call__(self, x: th.Tensor, t: th.Tensor, *args, **kwargs) -> th.Tensor:
        # ====
        # your code
        # compute new_t using self.timestep_map
        new_t = ...
        # ====
        return self.model(x, new_t, *args, **kwargs)

Let's apply it to our datasets

### Swiss Roll

In [None]:
assert os.path.exists(SAVE_DPPM_SWISS_PATH)
T = 100
model = ConditionalMLP(d_in=2, T=T, n_classes=2)
device = "cpu"

In [None]:
for num_timesteps in [100, 50, 25, 10]:
    use_timesteps = np.linspace(0, T - 1, num=num_timesteps).astype("int")  # 0-based indices into the original schedule
    spaced_ddpm = SpacedDDPM.from_pretrained(
        use_timesteps=use_timesteps, model=model, ckpt_path=SAVE_DPPM_SWISS_PATH
    )
    _ = spaced_ddpm.to(device)

    num_samples = 2000
    ys = th.randint(0, 2, size=(num_samples,), device=device)
    Xs, ys = spaced_ddpm.sample(ys)

    plt.figure(figsize=(4, 4))
    sns.scatterplot(x=Xs[:, 0], y=Xs[:, 1], hue=ys)
    plt.title(f"{num_timesteps = }/100")
    plt.show()

### MNIST

In [None]:
# Downloading pretrained model
!gdown -O "models/ddpm_mnist.pt" "https://drive.google.com/uc?id=1fSPB08M6aBNmhjRgSn3qpdq5hXl1Xhao"

In [None]:
assert os.path.exists(SAVE_DPPM_MNIST_PATH)
T = 1000
model = MyUNet()
device = th.device("cuda" if th.cuda.is_available() else "cpu")

In [None]:
num_timesteps = 100
use_timesteps = np.linspace(0, T - 1, num=num_timesteps).astype("int")  # 0-based indices into the original schedule
spaced_ddpm = SpacedDDPM.from_pretrained(
    use_timesteps=use_timesteps, model=model, ckpt_path=SAVE_DPPM_MNIST_PATH
)
_ = spaced_ddpm.to(device)

num_samples = 16
ys = th.randint(10, size=(num_samples,), device=device)  # MNIST has 10 classes
Xs, ys = spaced_ddpm.sample(y=ys)
show_images(Xs, ys)

## Classifier-free guidance

We can generate samples relevant to a given label by feeding this label to the model as an input.

However, first, there are no guarantees that we will get something relevant.

Second, there is no handle to increase or decrease the influence of the label.

Now we are going to implement a technique that mitigates these drawbacks.

Here is a quick reminder of the method.

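Diffusion models implicitly learn the score functions of the marginal distributions $q(x_t)$ and sample from them.

Let's make them sample from $q(x_t | y) = \frac{q(y | x_t) q(x_t)}{q(y)}$.

Taking the gradient of the logarithm (the term $\log q(y)$ does not depend on $x_t$), we get:
$$
    \nabla_{x_t} \log q(x_t | y)
    =
    \nabla_{x_t} \log q(y | x_t)
    +
    \nabla_{x_t} \log q(x_t)
$$

To increase the influence of the label, we scale the first term by a factor $s$ (with a slight abuse of notation, the left-hand side below denotes the resulting guided score):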
$$
    \nabla_{x_t} \log q(x_t | y)
    =
    s \cdot \nabla_{x_t} \log q(y | x_t)
    +
    \nabla_{x_t} \log q(x_t)
$$

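Using the score-model parametrization of diffusion models, $\nabla_{x_t} \log q(x_t) \approx -\epsilon_\theta(x_t) / \sqrt{1 - \bar{\alpha}_t}$, and the identity $\nabla_{x_t} \log q(y | x_t) = \nabla_{x_t} \log q(x_t | y) - \nabla_{x_t} \log q(x_t)$, we can approximate the aforementioned guided score function with noise models: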
$$
    -\frac{\hat{\epsilon}_\theta (x_t, y)}{\sqrt{1 - \bar{\alpha}_t}}
    =
    -s \cdot \frac{\epsilon_\theta (x_t, y) - \epsilon_\theta (x_t, \varnothing)}{\sqrt{1 - \bar{\alpha}_t}}
    -\frac{\epsilon_\theta (x_t, \varnothing)}{\sqrt{1 - \bar{\alpha}_t}}
$$
where $\epsilon_\theta (x_t, \varnothing)$ is an unconditional diffusion model.

As a result we have
$$
    \begin{align}
        \hat{\epsilon}_\theta (x_t, y)
        &=
        s \cdot (\epsilon_\theta (x_t, y) - \epsilon_\theta (x_t, \varnothing))
        +
        \epsilon_\theta (x_t, \varnothing)\\
        \hat{\epsilon}_\theta (x_t, y)
        &=
        \epsilon_\theta (x_t, \varnothing) \cdot (1 - s)
        + \epsilon_\theta (x_t, y) \cdot s
    \end{align}
$$
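
The final formula is a simple blend of two predictions. The sketch below applies it to plain lists of floats; `guided_eps` is a hypothetical helper and the numbers are made up to show the effect of the scale.

```python
def guided_eps(eps_cond, eps_uncond, guidance_scale):
    """Blend conditional and unconditional noise predictions element-wise."""
    return [
        (1.0 - guidance_scale) * u + guidance_scale * c
        for c, u in zip(eps_cond, eps_uncond)
    ]

eps_cond = [0.5, -1.0]    # made-up conditional prediction
eps_uncond = [0.25, 0.0]  # made-up unconditional prediction

only_uncond = guided_eps(eps_cond, eps_uncond, 0.0)
only_cond = guided_eps(eps_cond, eps_uncond, 1.0)
amplified = guided_eps(eps_cond, eps_uncond, 3.0)
print(only_uncond, only_cond, amplified)
```

With $s = 0$ the label is ignored, with $s = 1$ the plain conditional model is recovered, and with $s > 1$ the prediction is extrapolated away from the unconditional one.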

Usually the unconditional and conditional models are actually one and the same model.

In this case `null_label` is treated as an additional label, distinct from all real ones.
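
For a single network to learn both modes, some labels have to be replaced by `null_label` during training. The training cells pass `drop_label=0.4` to `get_labeled_data_loader`, whose implementation lives in `utils.py` and is not shown here; the sketch below is only an assumption about what such label dropout could look like, with a hypothetical helper `drop_labels`.

```python
import random

def drop_labels(labels, null_label, drop_prob, rng):
    """Replace each label with null_label independently with probability drop_prob."""
    return [null_label if rng.random() < drop_prob else label for label in labels]

rng = random.Random(0)
labels = [0, 1] * 500
dropped = drop_labels(labels, null_label=2, drop_prob=0.4, rng=rng)
null_fraction = sum(label == 2 for label in dropped) / len(dropped)
print(f"fraction of null labels: {null_fraction:.2f}")
```

Examples that keep their label train the conditional prediction $\epsilon_\theta(x_t, y)$, while examples with the null label train the unconditional one $\epsilon_\theta(x_t, \varnothing)$.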

In [None]:
class DDPM_CFG(SpacedDDPM):
    @th.no_grad()
    def sample(
        self, y: th.Tensor, guidance_scale: float=0., null_label: int=2
    ):
        assert self.shape is not None
        num_samples = y.shape[0]
        x = th.randn((num_samples, *self.shape), device=self.device, dtype=th.float32)
        indices = list(range(self.num_timesteps))[::-1]

        for i in tqdm(indices):
            t = th.tensor([i] * num_samples, device=x.device)
            # ====
            # your code
            # 1) get epsilon with hat using the model
            # 2) sample from the reverse diffusion
            eps_hat = self._predict_eps_hat(x, t, y, guidance_scale, null_label)
            x = ...
            # ====
        return x, y

    def _predict_eps_hat(
        self, x: th.Tensor, t: th.Tensor, y: th.Tensor, guidance_scale: float, null_label: int
    ):
        null_y = null_label * th.ones_like(y)

        if self.ema_counter < self.update_ema_after:
            model = self.model
        else:
            model = self.ema

        # ====
        # your code
        # get epsilon with hat using the model (use the final formula for eps_hat above)
        return ...
        # ====

### Swiss Roll

In [None]:
X, y = make_swiss_dataset(2000)

In [None]:
T = 100
# ====
# your code
# choose these parameters
BATCH_SIZE = 1024
LR = 0.01
WEIGHT_DECAY = 0.0
N_ITERS = 15000
# ====

model = ConditionalMLP(d_in=2, T=T, n_classes=2+1)
device = "cpu" # cpu is enough

if not os.path.exists(SAVE_DPPM_CFG_SWISS_PATH):
    th.manual_seed(0)
    random.seed(0)

    dataloader = get_labeled_data_loader(X, y, batch_size=BATCH_SIZE, shuffle=True, drop_label=0.4)
    ddpm = DDPM_CFG(betas=get_beta_schedule(T), model=model)

    train_model(
        ddpm=ddpm,
        dataloader=dataloader,
        lr=LR,
        weight_decay=WEIGHT_DECAY,
        n_iters=N_ITERS,
        device=device
    )
    th.save(ddpm.to("cpu").state_dict(), SAVE_DPPM_CFG_SWISS_PATH)
else:
    ddpm = DDPM_CFG.from_pretrained(model=model, ckpt_path=SAVE_DPPM_CFG_SWISS_PATH)

_ = ddpm.to(device)

In [None]:
num_samples = X.shape[0]
ys = th.randint(0, 2, size=(num_samples,), device=device)
for guidance_scale in [0.0, 0.5, 1.0, 2.0, 3.0]:
    Xs, ys = ddpm.sample(ys, guidance_scale=guidance_scale, null_label=2)
    plt.figure(figsize=(4, 4))
    sns.scatterplot(x=Xs[:, 0], y=Xs[:, 1], hue=ys)
    plt.title(f"{guidance_scale = }")
    plt.show()

In [None]:
num_timesteps = 50
use_timesteps = np.linspace(0, T - 1, num=num_timesteps).astype("int")  # 0-based indices into the original schedule
ddpm = DDPM_CFG.from_pretrained(
    use_timesteps=use_timesteps, model=model, ckpt_path=SAVE_DPPM_CFG_SWISS_PATH
)
_ = ddpm.to(device)

num_samples = X.shape[0]
ys = th.randint(0, 2, size=(num_samples,), device=device)
for guidance_scale in [0.0, 0.5, 1.0, 2.0, 3.0]:
    Xs, ys = ddpm.sample(ys, guidance_scale=guidance_scale, null_label=2)
    plt.figure(figsize=(4, 4))
    sns.scatterplot(x=Xs[:, 0], y=Xs[:, 1], hue=ys)
    plt.title(f"{guidance_scale = }")
    plt.show()

### MNIST

In [None]:
def show_images(images, ys, title=""):
    """Shows the provided images as sub-pictures in a square"""

    # Converting images to CPU numpy arrays
    if type(images) is th.Tensor:
        images = images.detach().cpu().numpy()
        ys = ys.detach().cpu().numpy()

    # Defining number of rows and columns
    rows = int(len(images) ** (1 / 2))
    cols = round(len(images) / rows)
    fig = plt.figure(figsize=(cols*2, rows*2))

    # Populating figure with sub-plots
    idx = 0
    for r in range(rows):
        for c in range(cols):
            fig.add_subplot(rows, cols, idx + 1)

            if idx < len(images):
                plt.imshow(images[idx][0], cmap="gray")
                plt.title(f"{int(ys[idx])}")
                plt.tick_params(bottom = False, labelbottom=False)
                idx += 1
    fig.suptitle(title, fontsize=30)

    # Showing the figure
    plt.show()

In [None]:
from torchvision.datasets.mnist import MNIST

def mnist_to_train_range(X):
    return ((X.astype("float32") / 255.) - 0.5) * 2

def mnist_from_train_range(X):
    return (((X.astype("float32") + 1.0) / 2) * 255.).astype("int")

dataset = MNIST("./datasets", download=True, train=True)
X = dataset.data.numpy().astype("float32")[:, None]
y = dataset.targets.numpy()
mnist_loader = get_labeled_data_loader(mnist_to_train_range(X), y, batch_size=64)

In [None]:
# Downloading pretrained model
!gdown -O "models/ddpm_cfg_mnist.pt" "https://drive.google.com/uc?id=1DoLq4PYoef5fo-tewy22o5P89no5V6NW"

In [None]:
T = 1000
# ====
# your code
# choose these parameters
BATCH_SIZE = 1024
LR = 0.01
WEIGHT_DECAY = 0.0
N_ITERS = 5000
# ====

model = MyUNet(use_null_cond=True)
device = th.device("cuda" if th.cuda.is_available() else "cpu")

if not os.path.exists(SAVE_DPPM_CFG_MNIST_PATH):
    th.manual_seed(0)
    np.random.seed(0)
    random.seed(0)

    ddpm = DDPM_CFG(betas=get_beta_schedule(T), model=model, clip_x0=True)
    dataloader = get_labeled_data_loader(mnist_to_train_range(X), y, batch_size=BATCH_SIZE, shuffle=True, drop_label=0.4)

    train_model(
        ddpm=ddpm,
        dataloader=dataloader,
        lr=LR,
        weight_decay=WEIGHT_DECAY,
        n_iters=N_ITERS,
        device=device
    )
    th.save(ddpm.to("cpu").state_dict(), SAVE_DPPM_CFG_MNIST_PATH)
else:
    num_timesteps = 100
    use_timesteps = np.linspace(0, T - 1, num=num_timesteps).astype("int")  # 0-based indices into the original schedule
    ddpm = DDPM_CFG.from_pretrained(use_timesteps=use_timesteps, model=model, ckpt_path=SAVE_DPPM_CFG_MNIST_PATH)

_ = ddpm.to(device)

In [None]:
num_samples = 16
ys = th.randint(10, size=(num_samples,), device=device)
Xs, ys = ddpm.sample(ys, guidance_scale=1.0, null_label=10)
show_images(Xs, ys)

In [None]:
num_samples = 16
ys = th.randint(10, size=(num_samples,), device=device)
Xs, ys = ddpm.sample(ys, guidance_scale=5.0, null_label=10)
show_images(Xs, ys)

In [None]:
num_samples = 16
ys = th.randint(10, size=(num_samples,), device=device)
Xs, ys = ddpm.sample(ys, guidance_scale=0.0, null_label=10)
show_images(Xs, ys)