# Noise Scheduling
Diffusion models start with a completely noised image, from which they predict the noise that has to be removed so that the initial image is restored. In the initial versions of denoising diffusion models, the noise was scheduled in a linear fashion and usually over several hundred steps. Modern implementations use alternative noise schedules that have different noise trajectories.

This notebook gives you an intuitive introduction to the purpose and inner workings of noise schedulers. It visualises different noise trajectories, compares them, and explains the effect of different parameter choices and why finding the right combination is crucial for training diffusion models.

In [None]:
import matplotlib.pyplot as plt
import torch

from omegaconf import DictConfig
from diffusion.noise_schedulers import LinearNoiseScheduler, CosineNoiseScheduler, SigmoidNoiseScheduler

from utils import convert_image_to_normalized_tensor, convert_normalized_tensors_to_images, get_picsum_image

In [None]:
IMAGE_ID = 237
TRAJECTORY_LEN = 11
NUM_TIMESTEPS = TRAJECTORY_LEN - 1 


In [None]:
config_linear_noise_scheduler = DictConfig({"num_timesteps": NUM_TIMESTEPS})
config_cosine_noise_scheduler = DictConfig({"num_timesteps": NUM_TIMESTEPS, "start": 0.0, "end": 1.0, "tau": 1.0})
config_sigmoid_noise_scheduler = DictConfig({"num_timesteps": NUM_TIMESTEPS, "start": 0.0, "end": 3.0, "tau": 1.0})

linear_noise_scheduler = LinearNoiseScheduler(config_linear_noise_scheduler)
cosine_noise_scheduler = CosineNoiseScheduler(config_cosine_noise_scheduler)
sigmoid_noise_scheduler = SigmoidNoiseScheduler(config_sigmoid_noise_scheduler)

In [None]:
image = get_picsum_image(IMAGE_ID, 128)
image_normalized, mean, std = convert_image_to_normalized_tensor(image)
images_normalized = image_normalized.unsqueeze(0).repeat(TRAJECTORY_LEN, 1, 1, 1)
ts_traj = torch.arange(TRAJECTORY_LEN)

noised_images_normalized_linear = linear_noise_scheduler(images_normalized, ts_traj, return_noise=False)
noised_images_linear = convert_normalized_tensors_to_images(noised_images_normalized_linear, mean, std)

noised_images_normalized_cosine = cosine_noise_scheduler(images_normalized, ts_traj, return_noise=False)
noised_images_cosine = convert_normalized_tensors_to_images(noised_images_normalized_cosine, mean, std)

noised_images_normalized_sigmoid = sigmoid_noise_scheduler(images_normalized, ts_traj, return_noise=False)
noised_images_sigmoid = convert_normalized_tensors_to_images(noised_images_normalized_sigmoid, mean, std)

## Noise Schedules
When training a diffusion model, we take images from the training set and add noise to them according to the state they are supposed to represent at a given timestep. There are different approaches to determining how much noise should be added to an image at a specific timestep. We call these definitions **noise schedules**. If you're unsure about how the diffusion process works, check out the notebook `notebooks/diffusion.ipynb` to get the big picture of the approach before you dive into the details. This implementation uses a continuous-time noise schedule function $\gamma(t)$ for all noise schedules, and all of them follow the same formula for noising an image:


$ x_t = \sqrt{\gamma(t)}I + \sqrt{1 - \gamma(t)}\epsilon$,


where $x_t$ is the noised image at timestep $t$, $I$ is the original image, and $\epsilon \sim N(0,1)$ is noise sampled from the standard normal distribution. What differentiates individual noise schedules is the implementation of the gamma function $\gamma(t)$. You may ask yourself why we use different noise schedules at all instead of one specific one. In the denoising direction, earlier (high-noise) and later (low-noise) timesteps have different characteristics, and depending on the chosen noise schedule we can allocate more time to a specific range of timesteps. Broadly speaking, in the early timesteps the model defines the general structure of the image, i.e. the broad composition, while in the late timesteps details are refined. For this reason it can be beneficial to increase the denoising difficulty (more noise added per step) in some timestep ranges, so that we can extend the time spent in others where we care more about the outcome, e.g. detail refinement.
In this notebook, we will take a look at three noise schedules used in diffusion: the linear, cosine, and sigmoid noise schedule.
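The noising formula above can be sketched in a few lines. This is a minimal illustration, not the code from `diffusion/noise_schedulers.py`: the function name `noise_image` and the random stand-in image are assumptions made for this example.

```python
import torch


def noise_image(image: torch.Tensor, gamma: float) -> torch.Tensor:
    """Return sqrt(gamma) * image + sqrt(1 - gamma) * eps with eps ~ N(0, 1)."""
    eps = torch.randn_like(image)
    return gamma ** 0.5 * image + (1.0 - gamma) ** 0.5 * eps


torch.manual_seed(0)
# Hypothetical stand-in for a normalized image (zero mean, unit variance).
image = torch.randn(3, 64, 64)

clean = noise_image(image, gamma=1.0)       # gamma = 1: the image is unchanged
half = noise_image(image, gamma=0.5)        # gamma = 0.5: equal signal and noise variance
pure_noise = noise_image(image, gamma=0.0)  # gamma = 0: only noise remains
```

A schedule then only has to decide which value of $\gamma$ is used at which timestep.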

### Linear Noise Schedule
The linear noise schedule is the most straightforward one and more of a toy example for a noise schedule than something used in production. Usually, when you read about the linear schedule, it refers to the version established in [Denoising Diffusion Probabilistic Models](https://arxiv.org/pdf/2006.11239), which behaves completely differently from the implementation used in this notebook. As explained above, our implementation uses continuous noise schedules defined via gamma functions $\gamma(t)$ rather than the discrete beta schedules ($\beta_1, \beta_2, \dots, \beta_T$) that you see in the original papers.

The gamma function is simply  


$\gamma(t) = 1 - \frac{t}{T}$


with $T$ being the total number of timesteps.
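As a quick sanity check, the formula can be evaluated directly. The helper `linear_gamma` below is a hypothetical name for this sketch and not necessarily how `LinearNoiseScheduler` is implemented.

```python
import torch


def linear_gamma(t: torch.Tensor, num_timesteps: int) -> torch.Tensor:
    """Linear gamma function: gamma(t) = 1 - t / T."""
    return 1.0 - t / num_timesteps


T = 10
ts = torch.arange(T + 1)      # timesteps 0, 1, ..., T
gammas = linear_gamma(ts, T)  # falls from 1 (clean image) to 0 (pure noise)
print(gammas)
```

Each step removes the same amount of signal, $\frac{1}{T}$, from $\gamma$.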

### Cosine Noise Schedule
The cosine noise schedule better preserves structure early, allows for finer detail refinement late, and allocates more capacity to the middle steps. The cosine noise schedule is often used as the baseline and is widely used in practice.
The original version was presented in [Improved Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2102.09672), but we use a modified version from [On the Importance of Noise Scheduling for Diffusion Models](https://arxiv.org/abs/2301.10972) that allows for more fine-grained parametrization. With the parameters `{"start": 0.0, "end": 1.0, "tau": 1.0}` you recover the original form of the cosine noise schedule.
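To make the role of `start`, `end`, and `tau` more concrete, here is a minimal sketch of a parametrized cosine gamma function. It follows the pseudocode of the second paper as we understand it; the name `cosine_gamma` and the `clip_min` argument are assumptions for this example, and `CosineNoiseScheduler` may differ in its details.

```python
import math

import torch


def cosine_gamma(t, start=0.0, end=1.0, tau=1.0, clip_min=1e-9):
    """Cosine-based gamma function on normalized timesteps t in [0, 1]."""
    # Values of the raw curve at both ends, used to rescale the output to [0, 1].
    v_start = math.cos(start * math.pi / 2) ** (2 * tau)
    v_end = math.cos(end * math.pi / 2) ** (2 * tau)
    # Map t from [0, 1] to [start, end] and evaluate cos^(2 * tau).
    output = torch.cos((t * (end - start) + start) * math.pi / 2) ** (2 * tau)
    output = (v_end - output) / (v_end - v_start)
    return output.clamp(clip_min, 1.0)


ts = torch.linspace(0, 1, 11)
gammas = cosine_gamma(ts)  # with the defaults this is cos(pi / 2 * t) ** 2
print(gammas)
```

In this sketch, `start` and `end` select which segment of the cosine curve is used, and `tau` is applied as an exponent that changes the curvature.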

### Sigmoid Noise Schedule
The sigmoid noise schedule empirically shows more stable performance on higher resolutions than cosine noise schedules and allows for more flexible parametrization of transition sharpness. The original implementation of this noise schedule can be found in [Scalable Adaptive Computation for Iterative Generation](https://arxiv.org/abs/2212.11972).
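A parametrized sigmoid gamma function can be sketched in the same way. This again follows the pseudocode in [On the Importance of Noise Scheduling for Diffusion Models](https://arxiv.org/abs/2301.10972) as we understand it. The name `sigmoid_gamma` and the `clip_min` argument are assumptions, the defaults are the values from this notebook's config, and `SigmoidNoiseScheduler` may differ in its details.

```python
import torch


def sigmoid_gamma(t, start=0.0, end=3.0, tau=1.0, clip_min=1e-9):
    """Sigmoid-based gamma function on normalized timesteps t in [0, 1]."""
    # Values of the raw curve at both ends, used to rescale the output to [0, 1].
    v_start = torch.sigmoid(torch.tensor(start / tau))
    v_end = torch.sigmoid(torch.tensor(end / tau))
    # Map t from [0, 1] to [start, end]; tau controls how sharp the transition is.
    output = torch.sigmoid((t * (end - start) + start) / tau)
    output = (v_end - output) / (v_end - v_start)
    return output.clamp(clip_min, 1.0)


ts = torch.linspace(0, 1, 11)
gammas = sigmoid_gamma(ts)  # falls from 1 at t = 0 to (almost) 0 at t = 1
print(gammas)
```

Smaller values of `tau` make the transition between "mostly signal" and "mostly noise" sharper, larger values make it smoother.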




Below, we will look at what the individual gamma functions look like and what signal-to-noise ratio they create at different timesteps, and get an intuitive feeling for how they translate to a noise trajectory in the diffusion process.

In [None]:
# One row per noise schedule, one column per timestep of the trajectory
fig, axes = plt.subplots(3, TRAJECTORY_LEN, figsize=(10, 3))


for row, noised_images, noise_schedule in zip(
        range(3), 
        [noised_images_linear, noised_images_cosine, noised_images_sigmoid], 
        ["Linear", "Cosine", "Sigmoid"],
    ):
    for col in range(TRAJECTORY_LEN):
        axes[row, col].imshow(noised_images[col])
        if row == 0:
            axes[row, col].set_title(f"t={ts_traj[col]}", fontsize=8, pad=2)

        if col == 0:
            axes[row, col].set_ylabel(noise_schedule, fontsize=8, rotation=0, labelpad=35)
            # Hide ticks, otherwise it messes up the layout
            axes[row, col].set_xticks([])
            axes[row, col].set_yticks([])
            axes[row, col].spines['top'].set_visible(False)
            axes[row, col].spines['right'].set_visible(False)
            axes[row, col].spines['bottom'].set_visible(False)
            axes[row, col].spines['left'].set_visible(False)

        if col != 0:
            axes[row, col].axis('off')

plt.suptitle("Trajectories for different noise schedules", fontsize=16)
plt.tight_layout()
plt.show()

## The Gamma Function
The gamma function describes how much of the signal, i.e. the image, is retained at each noising step and how quickly the signal-to-noise ratio changes from step to step. The simplest noise schedule defines gamma as $\gamma(t) = 1 - \frac{t}{T}$, but other schedules have advantages. With a cosine as the gamma function, for example, $\gamma(t)$ changes slowly near both ends of the trajectory and fastest in between, so the noise level does not change by the same amount at every step. In the plot below you see the gamma functions for the different noise schedules. You can also check how the cosine and sigmoid schedules change depending on the parameters `start`, `end`, and `tau`. You will find further information about what those parameters do in the code (`diffusion/noise_schedulers.py`) or in this paper: [On the Importance of Noise Scheduling for Diffusion Models](https://arxiv.org/abs/2301.10972). A quick note to better understand the plots: we work with normalized timesteps, i.e. $t \in [0,1]$, where $t = 0$ corresponds to the noise-free reference image and $t = 1$ to pure noise.

In [None]:
ts_plot = torch.linspace(0, 1, 1000)
gamma_linear = linear_noise_scheduler.gamma_func(ts_plot)
gamma_cosine = cosine_noise_scheduler.gamma_func(ts_plot)
gamma_sigmoid = sigmoid_noise_scheduler.gamma_func(ts_plot)

plt.plot(ts_plot, gamma_linear, label='Linear')
plt.plot(ts_plot, gamma_cosine, label='Cosine')
plt.plot(ts_plot, gamma_sigmoid, label='Sigmoid')
plt.title("Gamma functions of different noise schedules")
plt.xlabel('Timesteps (normalized)')
plt.ylabel('$\\gamma(t)$')
plt.legend()
plt.show()

## Signal-To-Noise Ratio
The output of the gamma function for a specific timestep is not yet the factor that is applied to the image and the noise. We know that when we add two independent Gaussian random variables, the resulting variance is the sum of the two variances. This would mean that by adding more and more noise to an image, the variance of the noised image would continually increase. We want to avoid that, since we want a consistent scale for the model across timesteps.
To achieve that, we use $x_t = \sqrt{\gamma_t}I + \sqrt{1 - \gamma_t}\epsilon$ instead of a plain weighted average of the values, e.g. $x_t = \gamma_tI + (1 - \gamma_t)\epsilon$, where $\epsilon \sim N(0,1)$, $\gamma_t$ is the output of the gamma function of the noise scheduler at timestep $t$, and $I$ is the normalized image. The square roots follow from how a scaling factor affects the variance of a Gaussian distribution: for a unit-variance image, the two terms then have variances $\gamma_t$ and $1 - \gamma_t$, which sum to 1.


$X \sim N(\mu, \sigma^2)\Rightarrow aX \sim N(a\mu, a^2\sigma^2), a \in \mathbb{R}$
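The effect of the square roots is easy to check numerically. The sketch below uses a random unit-variance vector as a hypothetical stand-in for a normalized image and compares the variance of both mixing rules for a few values of $\gamma_t$.

```python
import torch

torch.manual_seed(0)
n = 1_000_000
image = torch.randn(n)  # hypothetical stand-in for a normalized (unit-variance) image
eps = torch.randn(n)    # independent Gaussian noise

results = {}
for gamma in (0.9, 0.5, 0.1):
    # Variance-preserving mix: the variance is gamma + (1 - gamma) = 1
    sqrt_mix = gamma ** 0.5 * image + (1 - gamma) ** 0.5 * eps
    # Plain weighted average: the variance is gamma^2 + (1 - gamma)^2 <= 1
    linear_mix = gamma * image + (1 - gamma) * eps
    results[gamma] = (sqrt_mix.var().item(), linear_mix.var().item())
    print(f"gamma={gamma}: sqrt mix {results[gamma][0]:.3f}, weighted average {results[gamma][1]:.3f}")
```

The square-root version stays close to a variance of 1 for every $\gamma_t$, while the weighted average shrinks to about half of that at $\gamma_t = 0.5$.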



You can check out the signal-to-noise ratios (SNR) for the corresponding gamma functions below. A high SNR means that a lot of the signal, i.e. the image, is contained in the output.

In [None]:
# Ratio of the signal factor to the noise factor: sqrt(gamma) / sqrt(1 - gamma)
snr_linear = torch.sqrt(gamma_linear / (1 - gamma_linear))
snr_cosine = torch.sqrt(gamma_cosine / (1 - gamma_cosine))
snr_sigmoid = torch.sqrt(gamma_sigmoid / (1 - gamma_sigmoid))

plt.semilogy(ts_plot, snr_linear, label='Linear')
plt.semilogy(ts_plot, snr_cosine, label='Cosine')
plt.semilogy(ts_plot, snr_sigmoid, label='Sigmoid')
plt.title("SNR over timesteps for different noise schedules")
plt.xlabel('Timesteps (normalized)')
plt.ylabel('Signal-To-Noise Ratio')
plt.legend()
plt.show()

## Image Size Dependency

When you work with higher-resolution images, you will eventually notice that, despite using the same noise schedule, the results look fairly different (see the figure below). The perceived noise level in high-resolution images seems lower than in low-resolution ones. The reason is that images with higher resolutions have higher pixel redundancy, which means in simple terms that it's easier to guess the value of a pixel from its neighbours than at lower resolutions. This doesn't just affect our perception of the problem; it also makes it easier for the model to guess the noise.
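One way to make the redundancy argument tangible is to average neighbouring pixels. In the sketch below, a hypothetical "high-resolution" image is built by repeating every pixel of a random low-resolution image in a 4x4 block, which is an extreme case of pixel redundancy. Averaging over those blocks leaves the signal untouched, while the independent noise is averaged out.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
gamma, k = 0.5, 4

# Hypothetical redundant image: every low-resolution pixel is repeated in a k x k block.
low_res = torch.randn(1, 1, 128, 128)
high_res = F.interpolate(low_res, scale_factor=k, mode="nearest")  # 512 x 512
noise = torch.randn_like(high_res)

# Average over k x k neighbourhoods, separately for the signal and the noise.
signal_pooled = F.avg_pool2d(high_res, k)  # recovers the low-resolution image
noise_pooled = F.avg_pool2d(noise, k)      # standard deviation shrinks by about 1 / k

# Power ratio of the signal and noise terms before and after pooling.
snr_before = (gamma * high_res.var() / ((1 - gamma) * noise.var())).item()
snr_after = (gamma * signal_pooled.var() / ((1 - gamma) * noise_pooled.var())).item()
print(f"noise std after pooling: {noise_pooled.std().item():.3f}")
print(f"SNR gain from pooling: {snr_after / snr_before:.1f}x")
```

Real images are less redundant than this toy example, so the gain is smaller, but the direction is the same: the more neighbouring pixels agree, the more of the noise can be averaged away at the same $\gamma_t$.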

In [None]:
IMAGE_SIZES = [32, 64, 128, 256, 512, 1024]
GAMMA = 0.5

images = []
for img_size in IMAGE_SIZES:
    img = get_picsum_image(IMAGE_ID, img_size)
    img, mean, std = convert_image_to_normalized_tensor(img)
    noise = torch.randn_like(img)
    noised_image = GAMMA**.5 * img + (1 - GAMMA)**.5 * noise
    images += convert_normalized_tensors_to_images(noised_image, mean, std)

fig, axes = plt.subplots(1, len(IMAGE_SIZES), figsize=(len(IMAGE_SIZES)*2, 3))
axes = axes.flatten()

for idx, (img_size, img) in enumerate(zip(IMAGE_SIZES, images)):
    axes[idx].imshow(img)
    axes[idx].set_title(f"{img_size}x{img_size}", fontsize=14)
    axes[idx].axis('off')

plt.suptitle("Same SNR at different image resolutions ($\\gamma_t = 0.5$)", fontsize=16)
plt.tight_layout()
plt.show()

## Input Scale Factor
An obvious solution to the problem would be a hyperparameter search over the noise schedules, but there's a simpler way to counteract this problem at higher resolutions: we decrease the signal during the noising step. We do this by introducing an input scale factor $b$ into the noising equation:

$ x_t = \sqrt{\gamma_t}bI + \sqrt{1 - \gamma_t}\epsilon$

With this, we compensate for the pixel redundancy and improve the image generation performance of the trained model. Below you can see what images with different resolutions and input scale factors look like, and get an intuitive feeling for how $b$ affects the information contained in the image. Remember, $b=0$ means no information is retained at all and $b=1$ means the full image is used in the noising process.
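The scaled noising equation and its effect on the signal-to-noise ratio can be sketched as follows. The helper names `noise_image_scaled` and `power_snr` are assumptions for this example, and the SNR here is the ratio of signal power to noise power for a unit-variance image.

```python
import torch


def noise_image_scaled(image: torch.Tensor, gamma: float, b: float) -> torch.Tensor:
    """Return sqrt(gamma) * b * image + sqrt(1 - gamma) * eps with eps ~ N(0, 1)."""
    eps = torch.randn_like(image)
    return gamma ** 0.5 * b * image + (1 - gamma) ** 0.5 * eps


def power_snr(gamma: float, b: float) -> float:
    """Signal power over noise power, assuming a unit-variance image."""
    return (b ** 2 * gamma) / (1 - gamma)


torch.manual_seed(0)
# Hypothetical stand-in for a normalized image (zero mean, unit variance).
image = torch.randn(3, 256, 256)

snrs = {b: power_snr(0.5, b) for b in (0.2, 0.6, 1.0)}
scaled = noise_image_scaled(image, gamma=0.5, b=0.2)
scaled_var = scaled.var().item()  # about b^2 * gamma + (1 - gamma)
print(snrs, scaled_var)
```

For a fixed $\gamma_t$, the SNR shrinks with $b^2$, so a smaller $b$ acts like a noisier schedule. Note that the variance of $x_t$ is then $b^2\gamma_t + (1 - \gamma_t)$ rather than 1.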

In [None]:
IMAGE_SIZES = [32, 64, 128, 256, 512, 1024]
GAMMA = 0.5
SCALING_FACTORS = [0, .2, .4, .6, .8, 1]

original_images = [convert_image_to_normalized_tensor(get_picsum_image(IMAGE_ID, img_size)) 
                   for img_size in IMAGE_SIZES]
images = []
for b in SCALING_FACTORS:
    for img, mean, std in original_images:
        noise = torch.randn_like(img)
        noised_image = GAMMA**.5 * b * img + (1 - GAMMA)**.5 * noise
        images += convert_normalized_tensors_to_images(noised_image, mean, std)

fig, axes = plt.subplots(len(SCALING_FACTORS), len(IMAGE_SIZES), figsize=(len(IMAGE_SIZES)*2, len(SCALING_FACTORS)*2))
for row, b in enumerate(SCALING_FACTORS):
    for col, img_size in enumerate(IMAGE_SIZES):
        idx = row * len(IMAGE_SIZES) + col
        axes[row, col].imshow(images[idx])
        if row == 0:
            axes[row, col].set_title(f"{img_size}x{img_size}", fontsize=8, pad=2)

        if col == 0:
            axes[row, col].set_ylabel(f"$b={b}$", fontsize=8, rotation=0, labelpad=35)
            # Hide ticks, otherwise it messes up the layout
            axes[row, col].set_xticks([])
            axes[row, col].set_yticks([])
            axes[row, col].spines['top'].set_visible(False)
            axes[row, col].spines['right'].set_visible(False)
            axes[row, col].spines['bottom'].set_visible(False)
            axes[row, col].spines['left'].set_visible(False)

        if col != 0:
            axes[row, col].axis('off')

plt.suptitle("Same $\\gamma_t$ at different image resolutions and input scale factors ($\\gamma_t = 0.5$)", fontsize=16)
plt.tight_layout()
plt.show()