This repository contains a torch/luz implementation of Denoising Diffusion Implicit Models. The code in this repository is heavily influenced by the code in Béres (2022), which is mostly based on Song, Meng, and Ermon (2020), with a few ideas coming from Nichol and Dhariwal (2021) and Karras et al. (2022).
Denoising diffusion models are inspired by non-equilibrium thermodynamics (Sohl-Dickstein et al. 2015). First, a forward diffusion procedure is defined; it converts any complex data distribution into a simple, tractable distribution. We then learn a procedure to reverse the diffusion process.
While there's a strong theoretical foundation for denoising diffusion models, in practice the core component is a neural network capable of separating a noisy image into its image and noise parts. To sample new images we can then take pure noise and successively 'denoise' it until only the image part remains.
Originally, the forward diffusion process was defined as a Markov process that successively (for $t = 1, \dots, T$) adds Gaussian noise to the data:

$$q(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t;\; \sqrt{1 - \beta_t}\, x_{t-1},\; \beta_t I\right)$$

One important property of this process is that you can easily sample from $q(x_t \mid x_0)$ in closed form:

$$q(x_t \mid x_0) = \mathcal{N}\left(x_t;\; \sqrt{\bar\alpha_t}\, x_0,\; (1 - \bar\alpha_t) I\right), \qquad \alpha_t = 1 - \beta_t, \quad \bar\alpha_t = \prod_{s=1}^{t} \alpha_s$$

The diffusion rate $\beta_t$ is defined by a schedule that controls how fast the signal is destroyed. In Song, Meng, and Ermon (2020) and Ho, Jain, and Abbeel (2020), a so-called linear schedule is used. However, the linearity applies to $\beta_t$ itself, which causes the later part of the process to be almost pure noise; Nichol and Dhariwal (2021) instead propose a cosine schedule, which destroys information more gradually.
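To make the schedules concrete, here's a minimal sketch of a cosine schedule in the signal-rate/noise-rate parametrization used throughout this repository ($\text{signal}_t = \sqrt{\bar\alpha_t}$, $\text{noise}_t = \sqrt{1 - \bar\alpha_t}$, so that $\text{signal}_t^2 + \text{noise}_t^2 = 1$). The function name and rate bounds are illustrative, not the exact code from diffusion.R:

library(torch)

# map continuous diffusion times in [0, 1] to signal and noise rates along a
# cosine curve; the bounds keep the rates away from exactly 0 and 1
cosine_schedule <- function(diffusion_times, min_signal_rate = 0.02,
                            max_signal_rate = 0.95) {
  start_angle <- torch_acos(torch_tensor(max_signal_rate))
  end_angle <- torch_acos(torch_tensor(min_signal_rate))
  angles <- start_angle + diffusion_times * (end_angle - start_angle)
  list(signal = torch_cos(angles), noise = torch_sin(angles))
}

# forward diffusion in closed form: x_t = signal * x_0 + noise * eps
rates <- cosine_schedule(torch_rand(8, 1, 1, 1))
x0 <- torch_randn(8, 3, 64, 64)  # stand-in for a batch of normalized images
xt <- rates$signal * x0 + rates$noise * torch_randn_like(x0)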
Below we visualize the forward diffusion process with both the linear and the cosine schedule, for 10 diffusion steps.
Samples from $q(x_t \mid x_0)$ for linearly spaced values of $t$ with the linear schedule (top) and the cosine schedule (bottom). We can see that with the linear schedule, images become almost pure noise after the first half of the steps - this seems to hurt model performance according to Nichol and Dhariwal (2021). Also, in practice we never let the signal rate reach exactly 0 or 1 at the endpoints of the schedule.
The forward diffusion process thus is a fixed procedure to generate samples of $q(x_t \mid x_0)$; no learning is involved in it.

As $t$ grows, $x_t$ loses all information about $x_0$ and approaches pure Gaussian noise. What we actually need for generating images is the reverse: sampling from $q(x_{t-1} \mid x_t)$, which is intractable in general.

The intuition though is that, since $x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \varepsilon$ with $\varepsilon \sim \mathcal{N}(0, I)$, every noisy image is a weighted combination of an image part and a noise part. Thus, if we can estimate the noise $\varepsilon$ contained in $x_t$, we can also recover an estimate of $x_0$ and use it to take a denoising step.

There's another way to get an estimate of $x_0$: train the network to predict the image part directly instead of the noise - this is the loss_on configuration explored in the experiments below.
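To make that concrete, the inversion is a one-liner. A minimal sketch, reusing the rates and xt from the schedule sketch above (pred_noises stands in for a hypothetical network output):

# hypothetical network output: the predicted noise component of x_t
pred_noises <- torch_randn_like(xt)
# invert x_t = signal * x_0 + noise * eps to estimate the image part
pred_images <- (xt - rates$noise * pred_noises) / rates$signal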
Although denoising diffusion models can work with whatever kind of data distribution $q(x_0)$, in this repository we work with images.
Since we are dealing with images, we will use neural network architectures that are adapted for that domain. The neural network we are building takes a noisy image sampled from $q(x_t \mid x_0)$, along with the corresponding noise variances, and predicts its image and noise parts.
The core part of the neural network is a U-Net (Ronneberger, Fischer, and Brox 2015), which is a common architecture in domains where both the input and the output are image-like tensors. The U-Net takes as input a concatenation of the embedded noisy images and the (upsampled) embedding of the noise variances.
A sinusoidal embedding (Vaswani et al. 2017) is used to encode the diffusion times (or the noise variances) into the model. The visualization below shows how diffusion times are mapped to the embedding, assuming an embedding dimension of 32. Each row is an embedding vector for a given diffusion time. Sinusoidal embeddings have nice properties, like preserving relative distances (Kazemnejad 2019).
You can find the code for the sinusoidal embedding in diffusion.R.
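For reference, here's a minimal sketch of such an embedding, assuming embedding_dim / 2 frequencies spaced geometrically between two bounds (the function name and the exact bounds in diffusion.R may differ):

library(torch)

sinusoidal_embedding_sketch <- function(x, embedding_dim = 32,
                                        min_freq = 1, max_freq = 1000) {
  half <- embedding_dim %/% 2
  # geometrically spaced frequencies, one per sin/cos pair
  freqs <- torch_exp(torch_linspace(log(min_freq), log(max_freq), half))
  # x has shape (batch, 1, 1, 1); broadcast it against the frequencies
  angles <- x * freqs$view(c(1, half, 1, 1)) * 2 * pi
  torch_cat(list(torch_sin(angles), torch_cos(angles)), dim = 2)
}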
A U-Net is a convolutional neural network that successively downsamples the image resolution while increasing its depth. After a few downsampling blocks, it starts upsampling the representation and decreasing the channel depth. The main idea in U-Net is that, much like in Residual Networks, the upsampling blocks take as input both the representation from the previous upsampling block and the representation from the downsampling block at the corresponding resolution.
Unlike the original U-Net implementation, we use ResNet blocks (He et al. 2015) in the downsampling and upsampling stages of the U-Net. Each downsampling or upsampling block contains block_depth of those ResNet blocks. We also use the Swish activation function.
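As a rough sketch of what one of those blocks could look like, assuming batch normalization and Swish via nnf_silu() (the exact layer arrangement in unet.R may differ):

resnet_block <- nn_module(
  initialize = function(in_channels, out_channels) {
    self$norm <- nn_batch_norm2d(in_channels)
    self$conv1 <- nn_conv2d(in_channels, out_channels, kernel_size = 3, padding = 1)
    self$conv2 <- nn_conv2d(out_channels, out_channels, kernel_size = 3, padding = 1)
    # 1x1 convolution so the skip connection matches the output channels
    self$residual <- if (in_channels != out_channels)
      nn_conv2d(in_channels, out_channels, kernel_size = 1)
    else
      nn_identity()
  },
  forward = function(x) {
    out <- x |> self$norm() |> self$conv1() |> nnf_silu() |> self$conv2()
    out + self$residual(x)
  }
)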
Song, Meng, and Ermon (2020) and Ho, Jain, and Abbeel (2020) also use an attention layer at lower resolutions, but we didn't include one, for simplicity. The code for the U-Net can be found in unet.R.
Given the definitions of unet and sinusoidal_embedding, we can implement the diffusion model with:
diffusion <- nn_module(
  initialize = function(image_size, embedding_dim = 32, widths = c(32, 64, 96, 128),
                        block_depth = 2) {
    self$unet <- unet(2 * embedding_dim, embedding_dim, widths = widths,
                      block_depth = block_depth)
    self$embedding <- sinusoidal_embedding(embedding_dim = embedding_dim)
    self$conv <- nn_conv2d(image_size[1], embedding_dim, kernel_size = 1)
    self$upsample <- nn_upsample(size = image_size[2:3])
    self$conv_out <- nn_conv2d(embedding_dim, image_size[1], kernel_size = 1)
    # we initialize at zeros so the initial output of the network is also just zeros
    purrr::walk(self$conv_out$parameters, nn_init_zeros_)
  },
  forward = function(noisy_images, noise_variances) {
    # embed the noise variances and broadcast them to the image resolution
    embedded_variance <- noise_variances |>
      self$embedding() |>
      self$upsample()
    # project the noisy images into the embedding dimension
    embedded_image <- noisy_images |>
      self$conv()
    # concatenate along the channel dimension and run the U-Net
    unet_input <- torch_cat(list(embedded_image, embedded_variance), dim = 2)
    # return the network output; the training step derives the noise and
    # image predictions from it
    unet_input |>
      self$unet() |>
      self$conv_out()
  }
)
With the default hyper-parameters, here's the diffusion model summary:
diffusion(image_size = c(3, 64, 64))
An `nn_module` containing 3,562,211 parameters.
── Modules ─────────────────────────────────────────────────────────────────────
• unet: <nn_module> #3,561,984 parameters
• embedding: <nn_module> #0 parameters
• conv: <nn_conv2d> #128 parameters
• upsample: <nn_upsample> #0 parameters
• conv_out: <nn_conv2d> #99 parameters
The model is trained to recover the noise (or, optionally, the image) from the noisy inputs, by minimizing the loss between its predictions and the actual values. In luz, since this is not a conventional supervised model, we implement the training logic in the step() method of the diffusion module, which looks like:
diffusion_model <- nn_module(
... # other methods omitted, to keep the focus on step()
step = function() {
# input images are standardized
ctx$input <- images <- ctx$model$normalize(ctx$input)
# sample random diffusion times and use the scheduler to obtain the amount
# of variance that should be applied for each t
diffusion_times <- torch_rand(images$shape[1], 1, 1, 1, device = images$device)
rates <- self$diffusion_schedule(diffusion_times)
# forward diffusion - generate x_t
noises <- torch_randn_like(images)
images <- rates$signal * images + rates$noise * noises
# 'denoises' the noisy images.
# creates predictions for the image and noise part
ctx$pred <- ctx$model(images, rates)
loss <- self$loss(noises, ctx$pred$pred_noises)
# this `step()` method is also applied during validation.
# we only want to backprop during training though
if (ctx$training) {
ctx$opt$zero_grad()
loss$backward()
ctx$opt$step()
}
# saves the loss for correctly reporting metrics in luz
ctx$loss[[ctx$opt_name]] <- loss$detach()
}
...
)
We use GuildAI for training configuration. guildai automatically parses the train.R script and allows us to change any scalar value defined at the top level of the file.
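For example, the top of train.R might contain top-level scalars like the following (the values shown here are hypothetical defaults, but the flag names match the ones used in the experiments below), each of which becomes a flag:

# hypothetical top-level scalars in train.R; guildai turns each into a flag
dataset_name <- "flowers"   # or "pets"
batch_size <- 64
loss <- "mae"               # or "mse"
loss_on <- "noise"          # or "image"
schedule_type <- "cosine"   # or "linear"
num_workers <- 4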
For instance, you can run the training script with the default values using:
guildai::guild_run("train.R")
You can supply different hyper-parameter values by using the flags argument, e.g.:
guildai::guild_run("train.R", flags = list(batch_size = 128))
Besides allowing us to configure the training hyper-parameters, GuildAI also tracks all experiment files and results. Since we are passing luz_callback_tfevents() to fit when training the diffusion model, all metrics are also logged and can be visualized in the Guild View dashboard and accessed using guildai::runs_info().
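A sketch of what that wiring could look like (the dataloader train_dl, the number of epochs, and the optimizer choice are assumptions, not the exact train.R code):

# set up the custom-step module and log metrics as tfevents during fitting
fitted <- diffusion_model |>
  luz::setup(optimizer = torch::optim_adam) |>
  luz::fit(train_dl, epochs = 10, callbacks = list(
    luz::luz_callback_tfevents()
  ))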
We ran experiments for two different datasets:
- Oxford pets
- Oxford flowers
For each dataset we experimented with 3 configuration options, leaving all other hyper-parameters fixed at their default values. Those were:
- loss function: 'mae' or 'mse'
- loss_on: 'noise' or 'image' - whether the loss is computed on the predicted noise or on the predicted image values
- schedule_type: 'linear' or 'cosine'
Taking advantage of GuildAI integration, experiments can be run with:
guildai::guild_run("train.R", flags = list(
dataset_name = c("pets", "flowers"),
loss = c("mae", "mse"),
loss_on = c("noise", "image"),
schedule_type = c("linear", "cosine"),
num_workers = 8
))
We evaluated the KID (Bińkowski et al. 2018) metric on the final models in order to compare the quality of generated samples.
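For reference, KID is the squared maximum mean discrepancy (MMD) between InceptionV3 feature representations of real and generated images, estimated with a polynomial kernel:

$$\mathrm{KID} = \mathrm{MMD}^2_k\!\left(f(\text{real}),\, f(\text{generated})\right), \qquad k(x, y) = \left(\tfrac{1}{d}\, x^\top y + 1\right)^3$$

where $f$ extracts $d$-dimensional InceptionV3 features; lower values mean the generated samples are closer to the real data distribution.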
The results for the flowers dataset are shown below:
runs <- guildai::runs_info(label = "full_exp1") |>
  dplyr::filter(flags$dataset_name == "flowers")
runs |>
  tidyr::unpack(c(flags, scalars)) |>
  dplyr::select(
    loss,
    loss_on,
    schedule_type,
    noise_loss,
    image_loss,
    kid
  ) |>
  knitr::kable()
| loss | loss_on | schedule_type | noise_loss | image_loss | kid |
|------|---------|---------------|------------|------------|-----------|
| mse | image | cosine | 0.1714205 | 0.2668262 | 0.1390257 |
| mse | image | linear | 0.1472155 | 0.3130986 | 0.2125010 |
| mse | noise | cosine | 0.1655174 | 0.2694866 | 0.1315258 |
| mse | noise | linear | 0.1399971 | 0.3131811 | 0.1652256 |
| mae | image | cosine | 0.1665038 | 0.2591124 | 0.1038467 |
| mae | image | linear | 0.1435601 | 0.3056073 | 0.1912024 |
| mae | noise | cosine | 0.1644953 | 0.2626610 | 0.0934186 |
| mae | noise | linear | 0.1388959 | 0.3046939 | 0.1673736 |
We can see that, with the other hyper-parameters fixed, it's better to train the neural network on the MAE of the noises using a cosine schedule. Below we compare the images generated by each run. The ordering is the same as in the table above, so you can visualize the effect of the different KID values.
# we implement this as a function so we can reuse it for the pets dataset
plot_samples_from_runs <- function(runs) {
  images <- runs$dir |>
    lapply(function(x) {
      # load the fitted model saved by each run and sample from it
      model <- luz::luz_load(file.path(x, "luz_model.luz"))
      torch::with_no_grad({
        model$model$eval()
        model$model$generate(8, diffusion_steps = 20)
      })
    })
  images |>
    lapply(torch::torch_unbind) |>
    unlist() |>
    plot_tensors(ncol = 8, denormalize = identity)
}
plot_samples_from_runs(runs)
Below we also show the results for the Oxford pets dataset:
pet_runs <- guildai::runs_info(label = "full_exp1") |>
  dplyr::filter(flags$dataset_name == "pets")
pet_runs |>
  tidyr::unpack(c(flags, scalars)) |>
  dplyr::select(
    loss,
    loss_on,
    schedule_type,
    noise_loss,
    image_loss,
    kid
  ) |>
  knitr::kable()
| loss | loss_on | schedule_type | noise_loss | image_loss | kid |
|------|---------|---------------|------------|------------|-----------|
| mse | image | cosine | 0.1682112 | 0.2542121 | 0.1500503 |
| mse | image | linear | 0.1472592 | 0.3009916 | 0.1741628 |
| mse | noise | cosine | 0.1630672 | 0.2554463 | 0.1886360 |
| mse | noise | linear | 0.1390071 | 0.2988549 | 0.1187079 |
| mae | image | cosine | 0.1649045 | 0.2497923 | 0.1487082 |
| mae | image | linear | 0.1444500 | 0.2968645 | 0.2122481 |
| mae | noise | cosine | 0.1644861 | 0.2545764 | 0.1971107 |
| mae | noise | linear | 0.1390262 | 0.2952372 | 0.1810841 |
Again, the images below are ordered the same way as the above table, each row representing a different experiment configuration.
plot_samples_from_runs(pet_runs)
Even though the results for the second dataset are not really impressive, using larger models would likely improve the quality of the generated images.
We confirm that most of the practical decisions from the literature, like computing the loss on the noise instead of on the image, and using a cosine schedule instead of a linear one, can be reproduced and shown to perform better than the alternatives.
Images can be sampled from the model using the generate method. Remember to always set the model to eval() mode before sampling, so that the batch normalization layers are correctly applied.
box::use(torch[...])
box::use(./callbacks[plot_tensors])
fitted <- luz::luz_load(file.path(runs$dir[which.min(runs$scalars$kid)], "luz_model.luz"))
with_no_grad({
fitted$model$eval()
x <- fitted$model$generate(36, diffusion_steps = 25)
})
plot_tensors(x)
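Under the hood, generate runs the reverse diffusion loop. For intuition, here's a minimal sketch of one deterministic, DDIM-style denoising step, assuming a model that returns pred_images and pred_noises as in step() above (the function and argument names are illustrative, not the exact code from diffusion.R):

ddim_step <- function(model, x_t, t_now, t_next, schedule) {
  rates_next <- schedule(t_next)
  # separate the current sample into estimated image and noise parts
  pred <- model(x_t, schedule(t_now))
  # re-combine them at the next, lower noise level
  rates_next$signal * pred$pred_images + rates_next$noise * pred$pred_noises
}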
Bansal, Arpit, Eitan Borgnia, Hong-Min Chu, Jie S. Li, Hamid Kazemi, Furong Huang, Micah Goldblum, Jonas Geiping, and Tom Goldstein. 2022. “Cold Diffusion: Inverting Arbitrary Image Transforms Without Noise.” https://doi.org/10.48550/ARXIV.2208.09392.
Béres, András. 2022. “Denoising Diffusion Implicit Models.” https://keras.io/examples/generative/ddim/.
Bińkowski, Mikołaj, Danica J. Sutherland, Michael Arbel, and Arthur Gretton. 2018. “Demystifying MMD GANs.” https://doi.org/10.48550/ARXIV.1801.01401.
He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. 2015. “Deep Residual Learning for Image Recognition.” https://doi.org/10.48550/ARXIV.1512.03385.
Ho, Jonathan, Ajay Jain, and Pieter Abbeel. 2020. “Denoising Diffusion Probabilistic Models.” https://doi.org/10.48550/ARXIV.2006.11239.
Karras, Tero, Miika Aittala, Timo Aila, and Samuli Laine. 2022. “Elucidating the Design Space of Diffusion-Based Generative Models.” https://doi.org/10.48550/ARXIV.2206.00364.
Kazemnejad, Amirhossein. 2019. “Transformer Architecture: The Positional Encoding.” Kazemnejad.com. https://kazemnejad.com/blog/transformer_architecture_positional_encoding/.
Nichol, Alex, and Prafulla Dhariwal. 2021. “Improved Denoising Diffusion Probabilistic Models.” https://doi.org/10.48550/ARXIV.2102.09672.
Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. 2015. “U-Net: Convolutional Networks for Biomedical Image Segmentation.” https://doi.org/10.48550/ARXIV.1505.04597.
Sohl-Dickstein, Jascha, Eric A. Weiss, Niru Maheswaranathan, and Surya Ganguli. 2015. “Deep Unsupervised Learning Using Nonequilibrium Thermodynamics.” https://doi.org/10.48550/ARXIV.1503.03585.
Song, Jiaming, Chenlin Meng, and Stefano Ermon. 2020. “Denoising Diffusion Implicit Models.” https://doi.org/10.48550/ARXIV.2010.02502.
Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. “Attention Is All You Need.” https://doi.org/10.48550/ARXIV.1706.03762.
Weng, Lilian. 2021. “What Are Diffusion Models?” Lilianweng.github.io, July. https://lilianweng.github.io/posts/2021-07-11-diffusion-models/.
Footnotes
- (Bansal et al. 2022) seems to show that any lossy image transformation works.