## simple diffusion: End-to-end diffusion for high resolution images
### Unofficial PyTorch Implementation by Gian Favero

**Simple diffusion: End-to-end diffusion for high resolution images**  
[Emiel Hoogeboom](https://arxiv.org/search/cs?searchtype=author&query=Hoogeboom,+E), [Jonathan Heek](https://arxiv.org/search/cs?searchtype=author&query=Heek,+J), [Tim Salimans](https://arxiv.org/search/cs?searchtype=author&query=Salimans,+T)
https://arxiv.org/abs/2301.11093

GitHub Repository: https://github.com/faverogian/simpleDiffusion/blob/main/README.md

In [None]:
import torch
import torch.nn as nn
from torch.special import expm1
import math
from accelerate import Accelerator
from huggingface_hub import create_repo, upload_folder
from pathlib import Path
import os
from tqdm import tqdm
from ema_pytorch import EMA
import matplotlib.pyplot as plt

# helpers
def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def clip(x):
    """
    Function to clip the input tensor x to the range [-1, 1].

    Args:
    x (torch.Tensor): The input tensor to clip.

    Returns:
    x (torch.Tensor): The clipped tensor.
    """
    return torch.clamp(x, -1, 1)

### Simple Diffusion
We define a class that allows the training of a diffusion model using the simple diffusion paradigm, whose central idea is a shifted cosine noise schedule. The intuition is relatively simple: find a noise schedule that works at a baseline resolution and shift it proportionally to a new image size.

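For example, a litany of previous experiments in the literature shows that a cosine noise schedule works well on 64x64 images but fails to scale effectively to larger images. The reason is that neighboring pixels in a high-resolution image are strongly correlated, so the same per-pixel noise level destroys less information. Averaging a 2x2 patch halves the standard deviation of independent noise while leaving a locally smooth signal almost unchanged. Instead of sending up a prayer and hoping that the same results can be achieved on larger images (say 256x256), we shift the noise schedule proportionally to the difference in image size, which leads to more consistent results across the board.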
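
The pixel-averaging argument can be checked numerically. The sketch below is illustrative only. It uses a flat (constant) image as the simplest stand-in for locally smooth content, adds unit-variance Gaussian noise, and average-pools 2x2 patches with `torch.nn.functional.avg_pool2d`. The noise standard deviation halves while the signal is unchanged, so the pooled image has 4x the SNR (a log-SNR increase of $2 \log 2$).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Assumption: a flat image, so 2x2 average pooling leaves the signal unchanged
x = torch.full((1, 1, 256, 256), 0.5)
noise = torch.randn_like(x)  # unit-variance Gaussian noise (sigma = 1)
z = x + noise

# Average-pool the noisy image down to 128x128 and isolate the remaining noise
pooled_noise = F.avg_pool2d(z, kernel_size=2) - F.avg_pool2d(x, kernel_size=2)

print(noise.std().item())         # close to 1.0
print(pooled_noise.std().item())  # close to 0.5: noise std halves, so SNR grows 4x
```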

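A sample class definition is seen below. The functions defined in the following cells are meant to be methods of this class, which is why they refer to `self`. They are shown one cell at a time for readability: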

In [None]:
class simpleDiffusion(nn.Module):
    def __init__(
        self, 
        unet,
        image_size,
        noise_size=64,
        pred_param='v', 
        schedule='shifted_cosine', 
        steps=512
    ):
        super().__init__()

        # Training objective
        assert pred_param in ['v', 'eps'], "Invalid prediction parameterization. Must be 'v' or 'eps'"
        self.pred_param = pred_param

        # Sampling schedule
        assert schedule in ['cosine', 'shifted_cosine'], "Invalid schedule. Must be 'cosine' or 'shifted_cosine'"
        self.schedule = schedule
        self.noise_d = noise_size
        self.image_d = image_size

        # Model
        assert isinstance(unet, nn.Module), "Model must be an instance of torch.nn.Module."
        self.model = unet

        num_params = sum(p.numel() for p in self.model.parameters())
        print(f"Number of parameters: {num_params}")

        # Steps
        self.steps = steps
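
The class only requires `unet` to be an `nn.Module`. The cells below call it as `self.model(z_t, logsnr_t)` and expect a tensor shaped like `z_t` back. For smoke-testing the pipeline without a real U-Net, a tiny stand-in along the following lines could be used. `TinyDenoiser` is hypothetical and far too small to produce good samples. It simply accepts either a scalar or a per-sample `(B,)` log-SNR and injects it as a learned bias.

```python
import torch
import torch.nn as nn

class TinyDenoiser(nn.Module):
    """Hypothetical stand-in for a U-Net: model(z_t, logsnr_t) -> tensor shaped like z_t."""

    def __init__(self, channels=3, hidden=32):
        super().__init__()
        self.embed = nn.Linear(1, hidden)  # embeds the scalar log-SNR of each sample
        self.conv_in = nn.Conv2d(channels, hidden, kernel_size=3, padding=1)
        self.conv_out = nn.Conv2d(hidden, channels, kernel_size=3, padding=1)

    def forward(self, z_t, logsnr_t):
        logsnr_t = torch.as_tensor(logsnr_t, dtype=z_t.dtype, device=z_t.device)
        # Accept a scalar or a (B,) tensor and broadcast to one value per sample
        logsnr_t = logsnr_t.reshape(-1).expand(z_t.shape[0])
        # Rough rescaling of log-SNR values in [-15, 15] before the embedding
        cond = self.embed(logsnr_t[:, None] / 15.0)[:, :, None, None]
        h = self.conv_in(z_t) + cond
        return self.conv_out(torch.relu(h))

model = TinyDenoiser()
z = torch.randn(2, 3, 16, 16)
print(model(z, torch.tensor([0.5, -3.0])).shape)  # torch.Size([2, 3, 16, 16])
```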

### The Forward Diffusion Process

Since we are working with Gaussian noise in a Markovian manner, a noised image $z_t$ is obtained from a less noisy one $z_s$ through Gaussian transitions $p(z_t | z_s)$, where $0 \leq s < t \leq 1$. A composition of Gaussian transitions is itself Gaussian, so we do not have to sample these transitions one after another. Instead, we can draw $z_t \sim p(z_t | x)$ directly from the base image $x$. This forward process can be defined as: 
\begin{align*}
    \bm{z}_t = \alpha_t \bm{x} + \sigma_t \bm{\epsilon}; \quad \bm{\epsilon} \sim \mathcal{N}(\bm{\epsilon}; \bm{0}, \bm{\text{I}})
\end{align*}
In the above Gaussian noising process $\alpha_t$ and $\sigma_t$ are strictly positive scalar-valued functions of $t \in [0,1]$. We implement a method called "diffuse" to carry out this forward process for a given $\alpha_t$ and $\sigma_t$.

In [None]:
def diffuse(self, x, alpha_t, sigma_t):
    """
    Function to diffuse the input tensor x to a timepoint t with the given alpha_t and sigma_t.

    Args:
    x (torch.Tensor): The input tensor to diffuse.
    alpha_t (torch.Tensor): The alpha value at timepoint t.
    sigma_t (torch.Tensor): The sigma value at timepoint t.

    Returns:
    z_t (torch.Tensor): The diffused tensor at timepoint t.
    eps_t (torch.Tensor): The noise tensor at timepoint t.
    """
    eps_t = torch.randn_like(x)

    z_t = alpha_t * x + sigma_t * eps_t

    return z_t, eps_t
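
The claim that intermediate steps can be skipped can be verified with a small Monte Carlo check. This sketch makes two assumptions. It uses the variance-preserving parameterization described in the next section ($\alpha^2 = \text{sigmoid}(\lambda)$, $\sigma^2 = \text{sigmoid}(-\lambda)$). It also uses the standard transition coefficients $\alpha_{t|s} = \alpha_t / \alpha_s$ and $\sigma_{t|s}^2 = \sigma_t^2 - \alpha_{t|s}^2 \sigma_s^2$ from Kingma et al. The helper names are hypothetical. Going through $z_s$ and sampling $z_t$ directly should give the same mean and standard deviation.

```python
import torch

def vp_coeffs(logsnr):
    """alpha and sigma of a variance-preserving process at a given log-SNR."""
    logsnr = torch.as_tensor(logsnr, dtype=torch.float64)
    return torch.sqrt(torch.sigmoid(logsnr)), torch.sqrt(torch.sigmoid(-logsnr))

def two_step_vs_direct(x_value=0.5, logsnr_s=2.0, logsnr_t=-1.0, n=1_000_000, seed=0):
    g = torch.Generator().manual_seed(seed)
    x = torch.full((n,), x_value, dtype=torch.float64)
    alpha_s, sigma_s = vp_coeffs(logsnr_s)
    alpha_t, sigma_t = vp_coeffs(logsnr_t)

    # Transition p(z_t | z_s) for s < t (logsnr_s > logsnr_t)
    alpha_ts = alpha_t / alpha_s
    sigma_ts = torch.sqrt(sigma_t**2 - alpha_ts**2 * sigma_s**2)

    # Two steps: x -> z_s -> z_t
    z_s = alpha_s * x + sigma_s * torch.randn(n, generator=g, dtype=torch.float64)
    z_t_two_step = alpha_ts * z_s + sigma_ts * torch.randn(n, generator=g, dtype=torch.float64)

    # One step: x -> z_t directly
    z_t_direct = alpha_t * x + sigma_t * torch.randn(n, generator=g, dtype=torch.float64)

    return z_t_two_step, z_t_direct, float(alpha_t) * x_value, float(sigma_t)

two, direct, mean_ref, std_ref = two_step_vs_direct()
print(two.mean().item(), direct.mean().item(), mean_ref)
print(two.std().item(), direct.std().item(), std_ref)
```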

### The Log-SNR Schedule

In the above Gaussian noising process, $\alpha_t$ and $\sigma_t$ are related through the signal-to-noise ratio $\text{SNR}(t) = \alpha_t^2 / \sigma_t^2$. We work with its logarithm, the log-SNR $\lambda = \log \frac{\alpha_t^2}{\sigma_t^2}$. $\lambda$ is strictly monotonically decreasing in time: $\lambda_{\max}$ occurs at $t=0$ and $\lambda_{\min}$ occurs at $t=1$. $\lambda_{\min}$ is chosen low enough that $\bm{z}_1$ is approximately distributed as $\mathcal{N}(\bm{0}, \bm{\text{I}})$. At a given timepoint $t$, we use a function $\lambda = f_\lambda (t)$ from which we obtain $\alpha_t$ and $\sigma_t$. The forward process (or the destruction of $\bm{x}$) is commonly defined to be variance-preserving, imposing the constraint $\alpha_t^2 + \sigma_t^2 = 1$. Substituting $\sigma_t^2 = 1 - \alpha_t^2$ into $e^{\lambda} = \alpha_t^2 / \sigma_t^2$ and solving for $\alpha_t^2$ gives:
\begin{align*}
    \alpha(t)^2 &= \text{sigmoid}(\lambda) \\
    \sigma(t)^2 &= \text{sigmoid}(-\lambda)
\end{align*}
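
A quick numerical check of these identities is sketched below. `alpha_sigma_from_logsnr` is a hypothetical helper that mirrors the `torch.sqrt(torch.sigmoid(...))` expressions used later in `loss` and `ddpm_sampler_step`.

```python
import torch

def alpha_sigma_from_logsnr(logsnr):
    """Variance-preserving alpha_t and sigma_t for a given log-SNR lambda."""
    logsnr = torch.as_tensor(logsnr, dtype=torch.float64)
    alpha = torch.sqrt(torch.sigmoid(logsnr))
    sigma = torch.sqrt(torch.sigmoid(-logsnr))
    return alpha, sigma

lam = torch.linspace(-15.0, 15.0, 7, dtype=torch.float64)
alpha, sigma = alpha_sigma_from_logsnr(lam)

print(alpha**2 + sigma**2)             # numerically 1 everywhere (variance preserving)
print(torch.log(alpha**2 / sigma**2))  # recovers lam
```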

There are many schools of thought as to the best noise schedule for training diffusion models at various resolutions and on various image types (natural images, medical, etc.). A common choice is the cosine noise schedule introduced by Nichol and Dhariwal in "Improved Denoising Diffusion Probabilistic Models". Ignoring the small offset used in that paper, it corresponds to $\alpha_t = \cos(\pi t / 2)$ and $\sigma_t = \sin(\pi t / 2)$:
\begin{align*}
    \lambda(t) &= \log \frac{\alpha_t^2}{\sigma_t^2} \\
    &= -2 \log (\tan (\pi t / 2))
\end{align*}
We implement this in the method below:

In [None]:
def logsnr_schedule_cosine(self, t, logsnr_min=-15, logsnr_max=15):
    """
    Function to compute the logSNR schedule at timepoint t with cosine:

    logSNR(t) = -2 * log (tan (pi * t / 2))

    Taking into account boundary effects, the logSNR value at timepoint t is computed as:

    logsnr_t = -2 * log(tan(t_min + t * (t_max - t_min)))

    Args:
    t (float or torch.Tensor): The timepoint(s) t in [0, 1].
    logsnr_min (float): The minimum logSNR value, used to set t_max.
    logsnr_max (float): The maximum logSNR value, used to set t_min.

    Returns:
    logsnr_t (torch.Tensor): The logSNR value at timepoint t.
    """
    t_min = math.atan(math.exp(-0.5 * logsnr_max))
    t_max = math.atan(math.exp(-0.5 * logsnr_min))

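    # as_tensor accepts Python floats (sampling) and tensors (training) without a copy warning
    t = torch.as_tensor(t, dtype=torch.float32)
    logsnr_t = -2 * log(torch.tan(t_min + t * (t_max - t_min)))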

    return logsnr_t
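
The reason for the `t_min`/`t_max` remapping is easiest to see at the endpoints. The standalone sketch below uses hypothetical function names, float64, and the same ±15 defaults as above. It compares the raw formula with the remapped one. Without remapping, $\lambda(0)$ is $+\infty$ and $\lambda(1)$ is a huge finite negative number, because $\tan(\pi/2)$ is not exactly infinite in floating point. With remapping, $\lambda$ spans exactly $[-15, 15]$.

```python
import math
import torch

def logsnr_cosine_unclipped(t):
    """lambda(t) = -2 log tan(pi t / 2), with no boundary handling."""
    t = torch.as_tensor(t, dtype=torch.float64)
    return -2 * torch.log(torch.tan(math.pi * t / 2))

def logsnr_cosine_clipped(t, logsnr_min=-15.0, logsnr_max=15.0):
    """Same schedule with t remapped so that lambda spans [logsnr_min, logsnr_max]."""
    t = torch.as_tensor(t, dtype=torch.float64)
    t_min = math.atan(math.exp(-0.5 * logsnr_max))
    t_max = math.atan(math.exp(-0.5 * logsnr_min))
    return -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))

t = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float64)
print(logsnr_cosine_unclipped(t))  # +inf at t = 0, ~0 at t = 0.5, a huge negative value at t = 1
print(logsnr_cosine_clipped(t))    # approximately [15, 0, -15]
```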

### The Shifted-Cosine Noise Schedule

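As demonstrated in "simple diffusion: End-to-end diffusion for high resolution images" by Hoogeboom et al., diffusion models can be trained more efficiently when the noise schedule is adjusted for the image resolution. The shifted schedule multiplies the SNR at every timestep by $(\text{base dimension} / \text{image dimension})^2$. As a result, average-pooling a noisy high-resolution image down to the base resolution roughly recovers the SNR of the original schedule at that resolution, at least for locally smooth images. In log-space, this multiplication becomes an additive shift: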
\begin{align*}
    \lambda_{\text{shifted}}(t) = \lambda(t) + 2 \log \left( \frac{\text{base dimension}}{\text{image dimension}} \right)
\end{align*}
We implement this below:

In [None]:
def logsnr_schedule_cosine_shifted(self, t):
    """
    Function to compute the logSNR schedule at timepoint t with shifted cosine:

    logSNR_shifted(t) = logSNR(t) + 2 * log(noise_d / image_d)

    The image dimension (self.image_d) and the base noise dimension (self.noise_d)
    are read from the model's attributes.

    Args:
    t (float or torch.Tensor): The timepoint(s) t in [0, 1].

    Returns:
    logsnr_t_shifted (torch.Tensor): The shifted logSNR value at timepoint t.
    """
    logsnr_t = self.logsnr_schedule_cosine(t)
    logsnr_t_shifted = logsnr_t + 2 * math.log(self.noise_d / self.image_d)

    return logsnr_t_shifted
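
To see what the shift does, the standalone sketch below re-implements both schedules as free functions. The names are hypothetical, and it uses the same ±15 boundary handling as the defaults above. It compares a 256x256 image against a 64x64 base. The log-SNR is lowered by the same constant $2 \log(64/256) \approx -2.77$ at every $t$, which means the SNR is divided by 16 and the larger image is noised more aggressively at every timestep.

```python
import math
import torch

def logsnr_cosine(t, logsnr_min=-15.0, logsnr_max=15.0):
    """Cosine log-SNR schedule with t remapped to span [logsnr_min, logsnr_max]."""
    t = torch.as_tensor(t, dtype=torch.float64)
    t_min = math.atan(math.exp(-0.5 * logsnr_max))
    t_max = math.atan(math.exp(-0.5 * logsnr_min))
    return -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))

def logsnr_cosine_shifted(t, image_d, noise_d=64):
    """Shifted cosine: add 2 * log(noise_d / image_d) to the base schedule."""
    return logsnr_cosine(t) + 2 * math.log(noise_d / image_d)

t = torch.linspace(0.0, 1.0, 5, dtype=torch.float64)
base = logsnr_cosine(t)
shifted = logsnr_cosine_shifted(t, image_d=256)
offset = shifted - base

print(offset)             # constant, 2 * log(64 / 256) ≈ -2.77
print(torch.exp(offset))  # SNR ratio (64 / 256)^2 = 0.0625 at every t
```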

### The Loss Function

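Training a diffusion model is a repeated process of:
1. Adding noise to an image according to a randomly drawn timestep, $t$
2. Predicting (under some parameterization) the noise that was added to the original image
3. Updating the network to reduce the (weighted) mean squared error of that prediction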

We can attempt to predict the noise that was added to the image in various ways. Common choices of parameterization are $x$-, $\epsilon$- and $v$-prediction. Given $\bm{z}_t = \alpha_t \bm{x} + \sigma_t \bm{\epsilon}$, an $\epsilon$-prediction model tries to directly estimate the noise, while an $x$-prediction model attempts to directly recover the original image. A more modern take that came alongside progressive distillation (Salimans and Ho) is $v$-prediction. It estimates a mixture of the noise and the original image, $\bm{v} = \alpha_t \bm{\epsilon} - \sigma_t \bm{x}$, and has been found to improve stability. Both other quantities can be recovered from it: $\bm{x} = \alpha_t \bm{z}_t - \sigma_t \bm{v}$ and $\bm{\epsilon} = \sigma_t \bm{z}_t + \alpha_t \bm{v}$.
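
The conversions between the parameterizations can be checked numerically. This sketch picks an arbitrary log-SNR of 0.7 and random tensors, all of which are hypothetical values. It confirms that $\bm{x}$ and $\bm{\epsilon}$ are recovered exactly from $\bm{z}_t$ and $\bm{v} = \alpha_t \bm{\epsilon} - \sigma_t \bm{x}$. These are the same conversions that `loss` and `ddpm_sampler_step` apply to a $v$-prediction.

```python
import torch

torch.manual_seed(0)

# Hypothetical example values: one VP noise level plus a random "image" and noise
logsnr = torch.tensor(0.7, dtype=torch.float64)
alpha = torch.sqrt(torch.sigmoid(logsnr))
sigma = torch.sqrt(torch.sigmoid(-logsnr))
x = torch.randn(4, 8, dtype=torch.float64)
eps = torch.randn(4, 8, dtype=torch.float64)

z = alpha * x + sigma * eps  # forward process
v = alpha * eps - sigma * x  # v-prediction target

x_from_v = alpha * z - sigma * v    # as in ddpm_sampler_step for pred_param='v'
eps_from_v = sigma * z + alpha * v  # as in loss for pred_param='v'

print(torch.allclose(x_from_v, x), torch.allclose(eps_from_v, eps))  # True True
```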

All three are related by simple linear transformations of $\bm{z}_t$, but a plain MSE on each implies a different weighting of the loss across noise levels. They are therefore not interchangeable during training, even though each is a means to the same end. In our loss function, we provide the option to use a (shifted) cosine noise schedule, as well as the option to choose epsilon- or v-prediction. The loss is implemented as a weighted mean squared error between the true noise and the noise estimate implied by the model's prediction.

In [None]:
def loss(self, x):
    """
    A function to compute the loss of the model. The loss is computed as the mean squared error
    between the predicted noise tensor and the true noise tensor. Various prediction parameterizations
    imply various weighting schemes as outlined in Kingma et al. (2023)

    Args:
    x (torch.Tensor): The input tensor.

    Returns:
    loss (torch.Tensor): The loss value.
    """
    t = torch.rand(x.shape[0])

    if self.schedule == 'cosine':
        logsnr_t = self.logsnr_schedule_cosine(t)
    elif self.schedule == 'shifted_cosine':
        logsnr_t = self.logsnr_schedule_cosine_shifted(t)

    logsnr_t = logsnr_t.to(x.device)
    alpha_t = torch.sqrt(torch.sigmoid(logsnr_t)).view(-1, 1, 1, 1).to(x.device)
    sigma_t = torch.sqrt(torch.sigmoid(-logsnr_t)).view(-1, 1, 1, 1).to(x.device)
    z_t, eps_t = self.diffuse(x, alpha_t, sigma_t)
    pred = self.model(z_t, logsnr_t)

    if self.pred_param == 'v':
        eps_pred = sigma_t * z_t + alpha_t * pred
    else: 
        eps_pred = pred

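    # Apply min-SNR-gamma weighting with gamma = 5 (https://arxiv.org/pdf/2303.09556).
    # The MSE below is taken in eps-space for both parameterizations. There, the
    # min-SNR weight is min(SNR, 5) / SNR, which equals the v-space weight
    # min(SNR, 5) / (SNR + 1) because ||eps_pred - eps||^2 = alpha_t^2 * ||v_pred - v||^2.
    snr = torch.exp(logsnr_t)
    weight = snr.clamp(max=5) / snr

    weight = weight.view(-1, 1, 1, 1)

    loss = torch.mean(weight * (eps_pred - eps_t) ** 2)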

    return loss
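
The min-SNR-$\gamma$ paper (Hang et al., linked in the code above) defines the weight as $\min(\text{SNR}, \gamma)$ on an $x$-space MSE. For an $\epsilon$-space MSE this becomes $\min(\text{SNR}, \gamma)/\text{SNR}$, and for a $v$-space MSE it becomes $\min(\text{SNR}, \gamma)/(\text{SNR}+1)$. The loss above computes its MSE in $\epsilon$-space even for $v$-prediction, so it is worth checking which weight is right there. This standalone sketch uses hypothetical names and a random stand-in for the network output. It confirms that the $\epsilon$-space weight reproduces the $v$-space min-SNR loss exactly, since $\lVert \hat{\bm{\epsilon}} - \bm{\epsilon} \rVert^2 = \alpha_t^2 \lVert \hat{\bm{v}} - \bm{v} \rVert^2$.

```python
import torch

def min_snr_losses(logsnr, gamma=5.0, seed=0):
    """Per-sample min-SNR losses computed in eps-space and in v-space."""
    g = torch.Generator().manual_seed(seed)
    n = logsnr.shape[0]
    x = torch.randn(n, 16, generator=g, dtype=torch.float64)
    eps = torch.randn(n, 16, generator=g, dtype=torch.float64)
    v_pred = torch.randn(n, 16, generator=g, dtype=torch.float64)  # stand-in network output

    alpha = torch.sqrt(torch.sigmoid(logsnr))[:, None]
    sigma = torch.sqrt(torch.sigmoid(-logsnr))[:, None]
    z = alpha * x + sigma * eps
    v = alpha * eps - sigma * x
    eps_pred = sigma * z + alpha * v_pred  # eps estimate implied by the v prediction

    snr = torch.exp(logsnr)[:, None]
    clipped = snr.clamp(max=gamma)
    loss_eps_space = (clipped / snr * (eps_pred - eps) ** 2).mean(dim=1)
    loss_v_space = (clipped / (snr + 1) * (v_pred - v) ** 2).mean(dim=1)
    return loss_eps_space, loss_v_space

logsnr = torch.tensor([-10.0, -2.0, 0.0, 3.0, 10.0], dtype=torch.float64)
loss_eps, loss_v = min_snr_losses(logsnr)
print(torch.allclose(loss_eps, loss_v))  # True
```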

### Sampling

Estimating the clean image $x$ (or, equivalently, $\epsilon$ or $v$) from pure Gaussian noise in a single step is practically impossible. The shortcut used in the forward process, jumping straight from $x$ to $z_t$, therefore has no counterpart in the reverse process. Instead, the time interval $[0, 1]$ is discretized into $T$ steps, typically between 512 and 1000. Through a Markovian lens, the forward process is then a repeated sampling of $p(z_t | z_s)$ with $0 \leq s < t \leq 1$. The reverse process starts from $z_1 \sim \mathcal{N}(\bm{0}, \bm{\text{I}})$ and repeatedly samples $z_s \sim p(z_s | z_t)$ over the same set of timesteps before arriving at a sampled image.

The following sampling step comes from the Gaussian posterior $p(z_s | z_t, x)$, with the model's estimate of $x$ substituted for the unknown $x$. The derivation is shown beautifully in Appendix A.4 of "Variational Diffusion Models" by Kingma et al.

In [None]:
@torch.no_grad()
def ddpm_sampler_step(self, z_t, pred, logsnr_t, logsnr_s):
    """
    Function to perform a single step of the DDPM sampler.

    Args:
    z_t (torch.Tensor): The diffused tensor at timepoint t.
    pred (torch.Tensor): The predicted value from the model (v or eps).
    logsnr_t (torch.Tensor): The logSNR value at timepoint t.
    logsnr_s (torch.Tensor): The logSNR value at the earlier sampling timepoint s < t.

    Returns:
    mu (torch.Tensor): The mean of p(z_s | z_t).
    variance (torch.Tensor): The variance of p(z_s | z_t).
    """
    # c = 1 - SNR(t) / SNR(s), computed stably with expm1
    c = -expm1(logsnr_t - logsnr_s)
    alpha_t = torch.sqrt(torch.sigmoid(logsnr_t))
    alpha_s = torch.sqrt(torch.sigmoid(logsnr_s))
    sigma_t = torch.sqrt(torch.sigmoid(-logsnr_t))
    sigma_s = torch.sqrt(torch.sigmoid(-logsnr_s))

    # Convert the model output into an estimate of the clean image x
    if self.pred_param == 'v':
        x_pred = alpha_t * z_t - sigma_t * pred
    elif self.pred_param == 'eps':
        x_pred = (z_t - sigma_t * pred) / alpha_t

    # clip is the module-level helper defined at the top (it takes no self argument)
    x_pred = clip(x_pred)

    mu = alpha_s * (z_t * (1 - c) / alpha_t + c * x_pred)
    variance = (sigma_s ** 2) * c

    return mu, variance
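
The coefficients in `ddpm_sampler_step` can be checked against the variance-preserving transition used by Kingma et al., with $\alpha_{t|s} = \alpha_t / \alpha_s$ and $\sigma_{t|s}^2 = \sigma_t^2 - \alpha_{t|s}^2 \sigma_s^2$. Write $c = 1 - e^{\lambda_t - \lambda_s} = 1 - \text{SNR}(t)/\text{SNR}(s)$, which is the `c = -expm1(logsnr_t - logsnr_s)` line. The posterior $p(\bm{z}_s | \bm{z}_t, \bm{x}) = \mathcal{N}(\bm{\mu}, \sigma^2 \bm{\text{I}})$ then has:
\begin{align*}
    \frac{\sigma_{t|s}^2}{\sigma_t^2} &= 1 - \frac{\alpha_t^2 \sigma_s^2}{\alpha_s^2 \sigma_t^2} = 1 - \frac{\text{SNR}(t)}{\text{SNR}(s)} = c \\
    \bm{\mu} &= \frac{\alpha_{t|s} \sigma_s^2}{\sigma_t^2} \bm{z}_t + \frac{\alpha_s \sigma_{t|s}^2}{\sigma_t^2} \bm{x} = \alpha_s \left( \frac{1 - c}{\alpha_t} \bm{z}_t + c \, \bm{x} \right) \\
    \sigma^2 &= \frac{\sigma_{t|s}^2 \sigma_s^2}{\sigma_t^2} = c \, \sigma_s^2
\end{align*}
Substituting the model's estimate of $\bm{x}$ for $\bm{x}$ gives exactly the `mu` and `variance` returned by `ddpm_sampler_step`.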

We make use of the above algorithm by calling it repeatedly in a loop over $T$ timesteps. Beginning with pure Gaussian noise, we recover an image one step at a time. The noise injected at each step acts as a self-correction mechanism that lets later steps fix errors made by earlier ones. This matters because predicting an image in one shot is nearly impossible for modern U-Nets (or at least impossible to train in a reasonable time frame).

In [None]:
@torch.no_grad()
def sample(self, x):
    """
    Standard DDPM sampling procedure. Begun by sampling z_T ~ N(0, 1)
    and then repeatedly sampling z_s ~ p(z_s | z_t)

    Args:
    x (torch.Tensor): A reference batch; only its shape and device are used.

    Returns:
    x_pred (torch.Tensor): The sampled images, rescaled to [0, 1].
    """
    z_t = torch.randn(x.shape, device=x.device)

    # Steps T -> 2; the last step (from t = 1/T to t = 0) is handled separately below
    for t in reversed(range(2, self.steps+1)):
        u_t = t / self.steps
        u_s = (t - 1) / self.steps

        if self.schedule == 'cosine':
            logsnr_t = self.logsnr_schedule_cosine(u_t)
            logsnr_s = self.logsnr_schedule_cosine(u_s)
        elif self.schedule == 'shifted_cosine':
            logsnr_t = self.logsnr_schedule_cosine_shifted(u_t)
            logsnr_s = self.logsnr_schedule_cosine_shifted(u_s)

        logsnr_t = logsnr_t.to(x.device)
        logsnr_s = logsnr_s.to(x.device)

        # Pass one log-SNR per sample, matching the (B,)-shaped input used in loss()
        pred = self.model(z_t, logsnr_t.expand(z_t.shape[0]))
        # logsnr_t and logsnr_s are already tensors, so they are passed through as-is
        mu, variance = self.ddpm_sampler_step(z_t, pred, logsnr_t, logsnr_s)
        z_t = mu + torch.randn_like(mu) * torch.sqrt(variance)

    # Final step (t = 1/T -> 0): keep the mean and add no noise
    if self.schedule == 'cosine':
        logsnr_1 = self.logsnr_schedule_cosine(1/self.steps)
        logsnr_0 = self.logsnr_schedule_cosine(0)
    elif self.schedule == 'shifted_cosine':
        logsnr_1 = self.logsnr_schedule_cosine_shifted(1/self.steps)
        logsnr_0 = self.logsnr_schedule_cosine_shifted(0)

    logsnr_1 = logsnr_1.to(x.device)
    logsnr_0 = logsnr_0.to(x.device)

    pred = self.model(z_t, logsnr_1.expand(z_t.shape[0]))
    x_pred, _ = self.ddpm_sampler_step(z_t, pred, logsnr_1, logsnr_0)

    x_pred = clip(x_pred)

    # Convert x_pred to the range [0, 1]
    x_pred = (x_pred + 1) / 2

    return x_pred
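
One way to sanity-check the sampler's coefficients without a trained network is a toy dataset consisting of a single point $\bm{x}^\star$. For that dataset, the ideal denoiser always predicts $\bm{x}^\star$. The standalone sketch below runs ancestral sampling with such an oracle. It uses hypothetical names, float64 and the unshifted cosine schedule, and its step function takes the $x$-estimate directly, skipping the $v$/$\epsilon$ conversion. The loop stops at $t = 2/T$, and the last step to $t = 0$ uses the mean without added noise. If the step is correct, the output should land on $\bm{x}^\star$ up to a tiny residual.

```python
import math
import torch

def logsnr_cosine(t, logsnr_min=-15.0, logsnr_max=15.0):
    t = torch.as_tensor(t, dtype=torch.float64)
    t_min = math.atan(math.exp(-0.5 * logsnr_max))
    t_max = math.atan(math.exp(-0.5 * logsnr_min))
    return -2 * torch.log(torch.tan(t_min + t * (t_max - t_min)))

def sampler_step(z_t, x_pred, logsnr_t, logsnr_s):
    """Mean and variance of p(z_s | z_t) given an estimate of x."""
    c = -torch.expm1(logsnr_t - logsnr_s)
    alpha_t = torch.sqrt(torch.sigmoid(logsnr_t))
    alpha_s = torch.sqrt(torch.sigmoid(logsnr_s))
    sigma_s = torch.sqrt(torch.sigmoid(-logsnr_s))
    mu = alpha_s * (z_t * (1 - c) / alpha_t + c * x_pred)
    return mu, sigma_s**2 * c

def toy_sample(x_star, steps=64, seed=0):
    """Ancestral sampling with an oracle denoiser that always predicts x_star."""
    g = torch.Generator().manual_seed(seed)
    z = torch.randn(x_star.shape, generator=g, dtype=torch.float64)
    for i in reversed(range(2, steps + 1)):
        mu, var = sampler_step(z, x_star, logsnr_cosine(i / steps), logsnr_cosine((i - 1) / steps))
        z = mu + torch.sqrt(var) * torch.randn(x_star.shape, generator=g, dtype=torch.float64)
    # Last step (t = 1/steps -> 0): return the mean without added noise
    mu, _ = sampler_step(z, x_star, logsnr_cosine(1 / steps), logsnr_cosine(0.0))
    return mu

x_star = torch.tensor([0.5, -0.25, 1.0], dtype=torch.float64)
out = toy_sample(x_star)
print(out)  # very close to x_star
```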

### Implementing a Training Loop

From here the process is fairly straightforward. We minimize the loss above with respect to the parameters of the denoising network, typically a U-Net conditioned on the noise level (here, the log-SNR $\lambda_t$). Because $t$ is drawn uniformly from $[0, 1]$ during training, the network learns to denoise at every noise level. The number of discrete steps (again, typically 512 to 1000) only comes into play at sampling time. Unfortunately, the training loss is not a transparent indicator of progress in diffusion experiments. Training is therefore typically monitored by evaluating samples at regular intervals, either visually or with metrics such as FID.

In this training loop, distributed training and gradient accumulation are handled by the HuggingFace Accelerate library, while EMA tracking is done with the ema_pytorch library. Sample usage in a training script can be found on the GitHub page for this project: https://github.com/faverogian/simpleDiffusion/.

In [None]:
def train_loop(self, config, optimizer, train_dataloader, lr_scheduler):
    """
    A function to train the model.

    Args:
    config: Training configuration (mixed_precision, gradient_accumulation_steps, output_dir,
        push_to_hub, hub_model_id, num_epochs, save_image_epochs).
    optimizer (torch.optim.Optimizer): The optimizer to use for training.
    train_dataloader (torch.utils.data.DataLoader): Yields batches with an "images" key.
    lr_scheduler: The learning-rate scheduler, stepped once per batch.
    """
    # Initialize accelerator
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
    )
    if accelerator.is_main_process:
        if config.output_dir is not None:
            os.makedirs(config.output_dir, exist_ok=True)
        if config.push_to_hub:
            repo_id = create_repo(
                repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True
            ).repo_id

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        self.model, optimizer, train_dataloader, lr_scheduler
    )

    # Create an EMA copy of the unwrapped model (without any DDP wrapper)
    ema = EMA(
        accelerator.unwrap_model(model),
        beta=0.9999,
        update_after_step=100,
        update_every=10
    )

    for epoch in range(config.num_epochs):
        # self.loss() calls self.model, so point it at the prepared (possibly DDP-wrapped)
        # model and switch back to training mode after the previous epoch's sampling
        self.model = model
        model.train()

        progress_bar = tqdm(total=len(train_dataloader))
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            x = batch["images"]

            with accelerator.accumulate(model):
                loss = self.loss(x)
                loss = loss.to(next(model.parameters()).dtype)
                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Update EMA model parameters
            ema.update()

            progress_bar.update(1)

        progress_bar.close()

        # After each epoch, optionally sample some demo images on the main process
        if accelerator.is_main_process:
            self.model = accelerator.unwrap_model(model)
            self.model.eval()

            # Make directory for saving images
            os.makedirs(os.path.join(config.output_dir, "images"), exist_ok=True)

            if epoch % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                sample = self.sample(x[0].unsqueeze(0))
                sample = sample.detach().cpu().numpy().transpose(0, 2, 3, 1)
                image_path = os.path.join(config.output_dir, "images", f"sample_{epoch}.png")
                plt.imsave(image_path, sample[0])

            # Save the EMA weights into output_dir first so that upload_folder includes them
            if config.push_to_hub and epoch == config.num_epochs - 1:
                torch.save(ema.ema_model.state_dict(), os.path.join(config.output_dir, "ema_model.pth"))
                upload_folder(
                    repo_id=repo_id,
                    folder_path=config.output_dir,
                    commit_message="EMA model",
                )
                self.model.push_to_hub(config.hub_model_id, variant="fp16")
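
`train_loop` only touches a handful of attributes on `config`: `mixed_precision`, `gradient_accumulation_steps`, `output_dir`, `push_to_hub`, `hub_model_id`, `num_epochs` and `save_image_epochs`. A minimal dataclass that provides them might look like the sketch below. The class name and default values are hypothetical placeholders, not the project's settings.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainingConfig:
    """Hypothetical config holding the attributes train_loop reads; values are placeholders."""
    num_epochs: int = 100
    gradient_accumulation_steps: int = 1
    mixed_precision: str = "fp16"  # passed to accelerate.Accelerator, e.g. "no", "fp16" or "bf16"
    save_image_epochs: int = 10
    output_dir: str = "simple-diffusion-output"
    push_to_hub: bool = False
    hub_model_id: Optional[str] = None

config = TrainingConfig(num_epochs=5, save_image_epochs=1)
print(config)
```

The dataloader passed alongside it is likewise assumed to yield dictionaries with an `"images"` key, since the loop reads `batch["images"]`.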