# Progressive Distillation for Fast Sampling of Diffusion Models
Code for the <a href="https://openreview.net/forum?id=TIdIXIpzhoI">ICLR 2022 paper</a> by Tim Salimans and Jonathan Ho.
Model checkpoints to follow soon.

Make sure to use a TPU when running this notebook. You can enable it via Runtime -> Change runtime type -> Hardware accelerator.

<a href="https://colab.research.google.com/github/google-research/google_research/diffusion_distillation/blob/master/diffusion_distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Abstract**:
Diffusion models have recently shown great promise for generative modeling, outperforming GANs on perceptual quality and autoregressive models at density estimation. A remaining downside is their slow sampling time: generating high quality samples takes many hundreds or thousands of model evaluations. Here we make two contributions to help eliminate this downside: First, we present new parameterizations of diffusion models that provide increased stability when using few sampling steps. Second, we present a method to distill a trained deterministic diffusion sampler, using many steps, into a new diffusion model that takes half as many sampling steps. We then keep progressively applying this distillation procedure to our model, halving the number of required sampling steps each time. On standard image generation benchmarks like CIFAR-10, ImageNet, and LSUN, we start out with state-of-the-art samplers taking as many as 8192 steps, and are able to distill down to models taking as few as 4 steps without losing much perceptual quality; achieving, for example, a FID of 3.0 on CIFAR-10 in 4 steps. Finally, we show that the full progressive distillation procedure does not take more time than it takes to train the original model, thus representing an efficient solution for generative modeling using diffusion at both train and test time.

This notebook is intended as an easy way to get started with the Progressive Distillation algorithm. Reproducing the results from the paper exactly can be done using the hyperparameters in the provided config files, but this requires running at a larger scale and for longer than is practical in a notebook. We hope to be able to release the checkpoints for the trained model at a later time.
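
As a rough illustration of the halving procedure described in the abstract, the sketch below lists the (teacher, student) step counts for each distillation round. The helper name `halving_schedule` is hypothetical and is not part of the `diffusion_distillation` package.

```python
def halving_schedule(start_steps, end_steps):
    """Return (teacher_steps, student_steps) pairs, halving each round."""
    if start_steps < 1 or end_steps < 1:
        raise ValueError("step counts must be positive")
    schedule = []
    teacher_steps = start_steps
    # Keep distilling while the next student still uses at least end_steps.
    while teacher_steps // 2 >= end_steps:
        schedule.append((teacher_steps, teacher_steps // 2))
        teacher_steps //= 2
    return schedule


schedule = halving_schedule(8192, 4)
print(len(schedule), schedule[0], schedule[-1])
```

Going from 8192 to 4 sampling steps takes 11 rounds, since 8192 / 4 = 2^11.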

![FID vs number of steps](../fid_steps_graph.png)

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
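# Download the diffusion_distillation repository.
# Note: this uses GitHub's Subversion bridge, which GitHub retired in early 2024,
# so the checkout may fail; cloning the repository with git is an alternative.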
!apt-get -qq install subversion
!svn checkout https://github.com/google-research/google-research/trunk/diffusion_distillation
!pip install -r diffusion_distillation/requirements.txt --quiet

Checked out revision 8445.


In [None]:
import os
import requests
import functools
import jax
from jax.config import config
import jax.numpy as jnp
import flax
from matplotlib import pyplot as plt
import numpy as onp
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
from diffusion_distillation import diffusion_distillation

In [None]:
# configure JAX to use the TPU
if 'TPU_DRIVER_MODE' not in globals():
  url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver_nightly'
  resp = requests.post(url)
  TPU_DRIVER_MODE = 1
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
print(config.FLAGS.jax_backend_target)

grpc://10.115.199.154:8470


## Train a new diffusion model

In [None]:
# create model
config = diffusion_distillation.config.cifar_base.get_config()
model = diffusion_distillation.model.Model(config)

In [None]:
# init params 
state = jax.device_get(model.make_init_state())
state = flax.jax_utils.replicate(state)

In [None]:
# JIT compile training step
train_step = functools.partial(model.step_fn, jax.random.PRNGKey(0), True)
train_step = functools.partial(jax.lax.scan, train_step)  # for substeps
train_step = jax.pmap(train_step, axis_name='batch', donate_argnums=(0,))
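
To make the wrapping above more concrete, here is a small NumPy-only sketch of what scanning a step function over substeps amounts to. The `scan` function is a plain-Python emulation of the carry-and-stack behavior of `jax.lax.scan`, not the real implementation, and `toy_step` is a hypothetical stand-in for `model.step_fn` that assumes the argument order `(rng, is_train, state, batch)` implied by the `functools.partial` call.

```python
import functools

import numpy as np


def scan(f, init, xs):
    """Plain-Python emulation of the carry/stack semantics of jax.lax.scan."""
    carry = init
    ys = []
    for x in xs:  # iterate over the leading (substep) axis
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


def toy_step(rng, is_train, state, batch, lr=0.25):
    """Hypothetical stand-in for model.step_fn: one gradient step on a scalar.

    rng and is_train are unused here; they only mirror the assumed signature.
    """
    loss = np.mean((state - batch) ** 2)
    grad = 2.0 * (state - np.mean(batch))
    return state - lr * grad, loss


# Same wrapping pattern as in the notebook, minus pmap.
train_step = functools.partial(toy_step, None, True)
train_step = functools.partial(scan, train_step)

batches = np.ones((3, 4))  # 3 substeps, 4 examples each
final_state, losses = train_step(0.0, batches)
print(final_state, losses)
```

Each call returns the final state plus one metric entry per substep, which is consistent with the training loop averaging the metrics over axis 0.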

In [None]:
# build input pipeline
total_bs = config.train.batch_size
device_bs = total_bs // jax.device_count()
train_ds = model.dataset.get_shuffled_repeated_dataset(
    split='train',
    batch_shape=(
        jax.local_device_count(),  # for pmap
        config.train.substeps,  # for lax.scan over multiple substeps
        device_bs,  # batch size per device
    ),
    local_rng=jax.random.PRNGKey(0),
    augment=True)
train_iter = diffusion_distillation.utils.numpy_iter(train_ds)
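
The batch layout requested from the input pipeline can be sketched as follows. The numbers (a total batch size of 128, 8 devices, 2 substeps) are made up for illustration and are not taken from the config files, and `batch_shape_for` is a hypothetical helper.

```python
import numpy as np


def batch_shape_for(total_batch_size, device_count, local_device_count, substeps):
    """Leading batch dimensions: (local devices, substeps, per-device batch)."""
    # The divisibility check is an addition of this sketch; the notebook
    # simply uses integer division.
    if total_batch_size % device_count != 0:
        raise ValueError("total batch size must be divisible by device count")
    device_bs = total_batch_size // device_count
    return (local_device_count, substeps, device_bs)


shape = batch_shape_for(128, 8, 8, 2)
# CIFAR-10 images are 32x32 RGB, so a full batch carries three more axes.
images = np.zeros(shape + (32, 32, 3), dtype=np.float32)
# pmap splits the leading axis across devices; each device then scans
# over the substep axis of its own slice.
per_device = images[0]
print(shape, per_device.shape)
```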

In [None]:
# run training
for step in range(10):
  batch = next(train_iter)
  state, metrics = train_step(state, batch)
  metrics = jax.device_get(flax.jax_utils.unreplicate(metrics))
  metrics = jax.tree_util.tree_map(lambda x: float(x.mean(axis=0)), metrics)  # average over substeps
  print(metrics)

{'train/gnorm': 1.4675863981246948, 'train/loss': 0.39701905846595764}
{'train/gnorm': 1.002144455909729, 'train/loss': 0.23647935688495636}
{'train/gnorm': 1.164994239807129, 'train/loss': 0.1822689324617386}
{'train/gnorm': 0.8901512026786804, 'train/loss': 0.1539679318666458}
{'train/gnorm': 0.9623289108276367, 'train/loss': 0.15654230117797852}
{'train/gnorm': 0.7790379524230957, 'train/loss': 0.1380912959575653}
{'train/gnorm': 0.7820743322372437, 'train/loss': 0.1483442783355713}
{'train/gnorm': 0.6346055865287781, 'train/loss': 0.13180819153785706}
{'train/gnorm': 0.8465785980224609, 'train/loss': 0.13770630955696106}
{'train/gnorm': 0.719518780708313, 'train/loss': 0.12717439234256744}


## Distill a trained diffusion model

In [None]:
# create model
config = diffusion_distillation.config.cifar_distill.get_config()
model = diffusion_distillation.model.Model(config)

In [None]:
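# TODO: load the teacher params. The cells below read model.teacher_state,
# so a trained teacher must be loaded first, e.g. via the commented-out call.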
# model.load_teacher_state(config.distillation.teacher_checkpoint_path)

In [None]:
# init student state
init_params = diffusion_distillation.utils.copy_pytree(model.teacher_state.ema_params)
optim = model.make_optimizer_def().create(init_params)
state = diffusion_distillation.model.TrainState(
    step=model.teacher_state.step,
    optimizer=optim,
    ema_params=diffusion_distillation.utils.copy_pytree(init_params),
    num_sample_steps=model.teacher_state.num_sample_steps//2)

In [None]:
# build input pipeline
total_bs = config.train.batch_size
device_bs = total_bs // jax.device_count()
train_ds = model.dataset.get_shuffled_repeated_dataset(
    split='train',
    batch_shape=(
        jax.local_device_count(),  # for pmap
        config.train.substeps,  # for lax.scan over multiple substeps
        device_bs,  # batch size per device
    ),
    local_rng=jax.random.PRNGKey(0),
    augment=True)
train_iter = diffusion_distillation.utils.numpy_iter(train_ds)

In [None]:
steps_per_distill_iter = 10  # number of distillation steps per iteration of progressive distillation
end_num_steps = 4  # eventual number of sampling steps we want to use 
while state.num_sample_steps >= end_num_steps:

  # compile training step
  train_step = functools.partial(model.step_fn, jax.random.PRNGKey(0), True)
  train_step = functools.partial(jax.lax.scan, train_step)  # for substeps
  train_step = jax.pmap(train_step, axis_name='batch', donate_argnums=(0,))

  # train the student against the teacher model
  print('distilling teacher using %d sampling steps into student using %d steps'
        % (model.teacher_state.num_sample_steps, state.num_sample_steps))
  state = flax.jax_utils.replicate(state)
  for step in range(steps_per_distill_iter):
    batch = next(train_iter)
    state, metrics = train_step(state, batch)
    metrics = jax.device_get(flax.jax_utils.unreplicate(metrics))
    metrics = jax.tree_util.tree_map(lambda x: float(x.mean(axis=0)), metrics)  # average over substeps
    print(metrics)

  # student becomes new teacher for next distillation iteration
  model.teacher_state = jax.device_get(
      flax.jax_utils.unreplicate(state).replace(optimizer=None))

  # reset student optimizer for next distillation iteration
  init_params = diffusion_distillation.utils.copy_pytree(model.teacher_state.ema_params)
  optim = model.make_optimizer_def().create(init_params)
  state = diffusion_distillation.model.TrainState(
      step=model.teacher_state.step,
      optimizer=optim,
      ema_params=diffusion_distillation.utils.copy_pytree(init_params),
      num_sample_steps=model.teacher_state.num_sample_steps//2)

## Load a distilled model checkpoint and sample from it

In [None]:
# list all available distilled checkpoints
# TODO: use cloud bucket in public version

In [None]:
# create imagenet model
config = diffusion_distillation.config.imagenet64_base.get_config()
model = diffusion_distillation.model.Model(config)

In [None]:
# load distilled checkpoint for 8 sampling steps
loaded_params = diffusion_distillation.checkpoints.restore_from_path('/todo/imagenet_8', target=None)['ema_params']

In [None]:
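# Work around possible flax version mismatches: reshape each loaded array
# to the shape of the matching parameter in a freshly initialized model.
ema_params = jax.device_get(model.make_init_state()).ema_params
loaded_params = flax.core.unfreeze(loaded_params)
loaded_params = jax.tree_util.tree_map(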
    lambda x, y: onp.reshape(x, y.shape) if hasattr(y, 'shape') else x,
    loaded_params,
    flax.core.unfreeze(ema_params))
loaded_params = flax.core.freeze(loaded_params)
del ema_params
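
The reshaping step above can be illustrated on plain nested dictionaries. This is a NumPy-only sketch: `reshape_like` is a hypothetical helper that mimics the effect of the tree map, and the parameter names and shapes are invented for the example.

```python
import numpy as np


def reshape_like(loaded, template):
    """Reshape every array in `loaded` to the shape of its twin in `template`."""
    if isinstance(loaded, dict):
        return {key: reshape_like(value, template[key])
                for key, value in loaded.items()}
    if hasattr(template, 'shape'):
        return np.reshape(loaded, template.shape)
    return loaded  # leave non-array leaves untouched


# A checkpoint whose arrays carry the right values but differently shaped axes.
loaded = {'dense': {'kernel': np.arange(6.0), 'bias': np.zeros((1, 3))}}
template = {'dense': {'kernel': np.zeros((2, 3)), 'bias': np.zeros((3,))}}

fixed = reshape_like(loaded, template)
print(fixed['dense']['kernel'].shape, fixed['dense']['bias'].shape)
```

Reshaping only works when the element counts agree, so this addresses layout differences between versions rather than genuinely different architectures.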

In [None]:
# sample from the model
imagenet_classes = {'malamute': 249, 'siamese': 284, 'great_white': 2,
                    'speedboat': 814, 'reef': 973, 'sports_car': 817,
                    'race_car': 751, 'model_t': 661, 'truck': 867}
labels = imagenet_classes['sports_car'] * jnp.ones((16,), dtype=jnp.int32)
samples = model.samples_fn(
    rng=jax.random.PRNGKey(0),
    labels=labels,
    params=loaded_params,
    num_steps=8)
samples = jax.device_get(samples).astype(onp.uint8)

In [None]:
# visualize samples: pad each image with a 1-pixel white border
padded_samples = onp.pad(samples, ((0,0), (1,1), (1,1), (0,0)), mode='constant', constant_values=255)
# arrange the images in an nrows x ncols grid
nrows = int(onp.sqrt(padded_samples.shape[0]))
ncols = padded_samples.shape[0]//nrows
_, height, width, channels = padded_samples.shape
# tile: (nrows, ncols, h, w, c) -> (nrows, h, ncols, w, c) -> (nrows*h, ncols*w, c)
img_grid = (padded_samples
            .reshape(nrows, ncols, height, width, channels)
            .swapaxes(1, 2)
            .reshape(height*nrows, width*ncols, channels))
img = plt.imshow(img_grid)
plt.axis('off')
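
The tiling used in the visualization cell can be checked in isolation with a few tiny constant images. This is a NumPy-only sketch; `make_grid` is a hypothetical helper. Unlike the cell above, it drops any images that do not fit into a full `nrows x ncols` grid instead of failing on the reshape.

```python
import numpy as np


def make_grid(samples, pad=1, pad_value=255):
    """Tile a batch of images with shape (n, h, w, c) into one grid image."""
    padded = np.pad(samples, ((0, 0), (pad, pad), (pad, pad), (0, 0)),
                    mode='constant', constant_values=pad_value)
    n, height, width, channels = padded.shape
    nrows = int(np.sqrt(n))
    ncols = n // nrows
    # Keep only as many images as fit the grid (an addition of this sketch).
    grid = padded[:nrows * ncols].reshape(nrows, ncols, height, width, channels)
    # Move each image's height axis next to the grid-row axis, then flatten.
    grid = grid.swapaxes(1, 2)
    return grid.reshape(nrows * height, ncols * width, channels)


# Four 2x2 RGB images, where image i is filled with the value i.
samples = np.stack([np.full((2, 2, 3), i, dtype=np.uint8) for i in range(4)])
grid = make_grid(samples)
print(grid.shape)
```

Image `k` ends up in grid row `k // ncols` and grid column `k % ncols`, so samples are laid out left to right and then top to bottom.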