## Installation

In [None]:
!pip install git+https://github.com/google-research/swirl-dynamics.git@main

## Example - training an unconditional diffusion model

### Dataset

First, we need a dataset containing samples whose distribution the diffusion model will learn. This is application dependent, so below we use a pair of dummy training and evaluation dataloaders; replace them with real ones for your specific use case.

Our code setup accepts any Python iterable as a dataloader. Each iterable is expected to keep yielding dictionaries with a field named `x`, whose value is a NumPy array of shape `(batch, *spatial_dims, channels)`.

In [None]:
import itertools
import numpy as np

In [None]:
# A batch with 8 2D samples (with spatial dim of 64x64 and 1 channel)
fake_batch = {"x": np.ones((8, 64, 64, 1))}

train_dataloader = eval_dataloader = itertools.repeat(fake_batch)
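
As a rough illustration of what a replacement might look like, the sketch below wraps an in-memory NumPy array in a generator that yields `{"x": batch}` dictionaries indefinitely. The `make_dataloader` helper, its arguments, and the random `toy_data` array are hypothetical and not part of `swirl_dynamics`; any iterable producing batches of the expected shape should work equally well.

```python
import numpy as np


def make_dataloader(data, batch_size, seed=0):
  """Yields {"x": batch} forever, reshuffling the data at every epoch.

  Hypothetical helper for illustration; not a swirl_dynamics API.
  """
  num_samples = data.shape[0]
  if batch_size > num_samples:
    raise ValueError("batch_size must not exceed the number of samples")
  rng = np.random.default_rng(seed)
  while True:
    perm = rng.permutation(num_samples)
    # Drop the last incomplete batch so that every batch has the same shape.
    for start in range(0, num_samples - batch_size + 1, batch_size):
      yield {"x": data[perm[start:start + batch_size]]}


# Toy stand-in for a real dataset: 32 samples of shape (64, 64, 1).
toy_data = np.random.default_rng(0).normal(size=(32, 64, 64, 1))
toy_dataloader = make_dataloader(toy_data, batch_size=8)
first_batch = next(toy_dataloader)
print(first_batch["x"].shape)
```

Keeping the batch shape fixed is a deliberate choice here: it avoids a differently shaped final batch, which would typically trigger a recompilation of a jitted step function.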

### Architecture

Next, let's define the U-Net backbone. The "preconditioning" merely ensures that the inputs and outputs of the network are roughly standardized (for more details, see Appendix B.6 of [this paper](https://arxiv.org/abs/2206.00364)).

In [None]:
from swirl_dynamics.lib.diffusion import unets

In [None]:
denoiser_model = unets.PreconditionedDenoiser(
    out_channels=1,
    num_channels=(64, 128, 256),
    downsample_ratio=(2, 2, 2),
    num_blocks=4,
    noise_embed_dim=128,
    padding="SAME",
    use_attention=True,
    use_position_encoding=False,
    num_heads=8,
    sigma_data=0.25,  # standard deviation of the entire dataset
)
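
To make the role of preconditioning more concrete, here is a minimal NumPy sketch of the input and output scalings given in the paper cited above. It assumes that the library follows those formulas; `edm_preconditioning`, `preconditioned_denoise`, and `raw_network` are illustrative names, not `swirl_dynamics` APIs.

```python
import numpy as np


def edm_preconditioning(sigma, sigma_data):
  """Scaling coefficients from the cited paper (assumed, not library code)."""
  total_var = sigma**2 + sigma_data**2
  c_skip = sigma_data**2 / total_var
  c_out = sigma * sigma_data / np.sqrt(total_var)
  c_in = 1.0 / np.sqrt(total_var)
  return c_skip, c_out, c_in


def preconditioned_denoise(raw_network, x_noisy, sigma, sigma_data):
  """Wraps a raw network so that its input and output are roughly standardized."""
  c_skip, c_out, c_in = edm_preconditioning(sigma, sigma_data)
  return c_skip * x_noisy + c_out * raw_network(c_in * x_noisy, sigma)


c_skip, c_out, c_in = edm_preconditioning(sigma=0.25, sigma_data=0.25)
# A raw network that outputs zeros leaves only the skip connection.
zero_network_output = preconditioned_denoise(
    lambda x, sigma: np.zeros_like(x), np.ones((2, 4, 4, 1)), 0.25, 0.25
)
```

With these scalings, the network input `c_in * x_noisy` has unit variance whenever the clean data has standard deviation `sigma_data`, which is why that value should reflect your dataset.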

### Training

In [None]:
import jax
import jax.numpy as jnp
import optax
from orbax import checkpoint

from swirl_dynamics.lib.diffusion import diffusion
from swirl_dynamics.projects.probabilistic_diffusion import unconditional
from swirl_dynamics.templates import callbacks
from swirl_dynamics.templates import train

For diffusion model training, the U-Net backbone defined above serves as a denoiser: it takes as input a batch of samples corrupted by isotropic Gaussian noise and outputs its best guess of what the uncorrupted samples would be.

Besides the backbone architecture, we also need to specify how to sample the noise levels (i.e. standard deviations) used to corrupt the samples and the weighting for each noise level in the loss function (for other options and configurations, see [`swirl_dynamics.lib.diffusion.diffusion`](https://github.com/google-research/swirl-dynamics/blob/main/swirl_dynamics/lib/diffusion/diffusion.py)):

In [None]:
diffusion_scheme = diffusion.Diffusion.create_variance_exploding(
    sigma=diffusion.tangent_noise_schedule(),
    data_std=1.0,
)
model = unconditional.DenoisingModel(
    input_shape=(64, 64, 1),  # this must agree with the expected sample shape (without the batch dimension)
    denoiser=denoiser_model,
    noise_sampling=diffusion.log_uniform_sampling(
        diffusion_scheme, clip_min=1e-4, uniform_grid=True,
    ),
    noise_weighting=diffusion.edm_weighting(data_std=1.0),
)
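
The training objective described above can be sketched in plain NumPy: draw one noise level per sample, corrupt the batch, and penalize the weighted squared error of the denoiser output. This is only an illustration under stated assumptions: the weighting follows the formula in the paper cited earlier (assumed to correspond to `edm_weighting`), the noise range `sigma_max=100.0` is made up, and all function names are hypothetical rather than library APIs.

```python
import numpy as np


def log_uniform_sigma(rng, batch_size, sigma_min, sigma_max):
  """Samples noise levels uniformly in log space."""
  log_sigma = rng.uniform(np.log(sigma_min), np.log(sigma_max), size=batch_size)
  return np.exp(log_sigma)


def edm_weight(sigma, data_std):
  """Loss weight per noise level, as given in the cited paper (assumed)."""
  return (sigma**2 + data_std**2) / (sigma * data_std) ** 2


def denoising_loss(denoise_fn, x, rng, sigma_min=1e-4, sigma_max=100.0,
                   data_std=1.0):
  batch_size = x.shape[0]
  sigma = log_uniform_sigma(rng, batch_size, sigma_min, sigma_max)
  # Reshape to (batch, 1, ..., 1) so that sigma broadcasts over each sample.
  sigma_b = sigma.reshape((batch_size,) + (1,) * (x.ndim - 1))
  x_noisy = x + sigma_b * rng.standard_normal(x.shape)
  squared_error = (denoise_fn(x_noisy, sigma_b) - x) ** 2
  per_sample = squared_error.reshape(batch_size, -1).mean(axis=1)
  return float(np.mean(edm_weight(sigma, data_std) * per_sample))


x = np.ones((8, 64, 64, 1))
# A denoiser that always returns the clean data incurs zero loss ...
perfect_loss = denoising_loss(
    lambda x_noisy, sigma: np.ones_like(x_noisy), x, np.random.default_rng(0)
)
# ... whereas one that returns its noisy input unchanged does not.
identity_loss = denoising_loss(
    lambda x_noisy, sigma: x_noisy, x, np.random.default_rng(0)
)
```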

We are now ready to define the learning parameters.

In [None]:
# !rm -R -f $workdir  # optional: clear the working directory

In [None]:
num_train_steps = 10000  #@param
workdir = "/tmp/diffusion_demo"  #@param
initial_lr = 0.0  #@param
peak_lr = 1e-4  #@param
warmup_steps = 1000  #@param
end_lr = 1e-6  #@param
ema_decay = 0.999  #@param
ckpt_interval = 1000  #@param
max_ckpt_to_keep = 5  #@param
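
The learning-rate values above are passed to `optax.warmup_cosine_decay_schedule` in the trainer setup: a linear warmup from `initial_lr` to `peak_lr`, followed by a cosine decay to `end_lr`. The plain-Python function below is a sketch of that shape for quick inspection without installing anything. `warmup_cosine_lr` is a hypothetical helper rather than the optax implementation, and it assumes that `decay_steps` counts the total schedule length including warmup.

```python
import math


def warmup_cosine_lr(step, init_value=0.0, peak_value=1e-4, warmup_steps=1000,
                     decay_steps=10000, end_value=1e-6):
  """Linear warmup followed by cosine decay (illustrative sketch)."""
  if step < warmup_steps:
    return init_value + (peak_value - init_value) * step / warmup_steps
  # Fraction of the decay phase completed, clipped to 1 after the last step.
  progress = min((step - warmup_steps) / (decay_steps - warmup_steps), 1.0)
  cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
  return end_value + (peak_value - end_value) * cosine


for step in (0, 500, 1000, 5500, 10000):
  print(step, warmup_cosine_lr(step))
```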

To start training, we first need to initialize the trainer.

In [None]:
# NOTE: use `unconditional.DistributedDenoisingTrainer` for multi-device
# training with data parallelism
trainer = unconditional.DenoisingTrainer(
    model=model,
    rng=jax.random.PRNGKey(888),
    optimizer=optax.adam(
        learning_rate=optax.warmup_cosine_decay_schedule(
            init_value=initial_lr,
            peak_value=peak_lr,
            warmup_steps=warmup_steps,
            decay_steps=num_train_steps,
            end_value=end_lr,
        ),
    ),
    # We keep track of an exponential moving average (EMA) of the model
    # parameters over training steps. This alleviates the "color-shift"
    # problems known to exist in diffusion models.
    ema_decay=ema_decay,
)
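
The exponential moving average mentioned in the comment above can be sketched in a few lines. This is a generic NumPy illustration with a hypothetical `ema_update` helper and a flat dictionary of parameters; the trainer's internal implementation may differ.

```python
import numpy as np


def ema_update(ema_params, new_params, decay):
  """Blends the running average with the latest parameters (sketch)."""
  return {
      name: decay * ema_params[name] + (1.0 - decay) * new_params[name]
      for name in ema_params
  }


decay = 0.999
ema_params = {"w": np.zeros(3)}
for _ in range(1000):
  # Stand-in for the parameters produced by an optimizer step.
  new_params = {"w": np.ones(3)}
  ema_params = ema_update(ema_params, new_params, decay)
```

With `decay = 0.999`, the average has a time scale of about `1 / (1 - decay)` = 1000 steps, so in this toy run it has only partially caught up with the constant target value.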

Now we are ready to start training. Two callbacks are passed to assist with monitoring and checkpointing.

The first step will be a little slow because JAX needs to JIT-compile the step function (the same goes for the first step in which evaluation is performed). Subsequent steps should run much faster.

In [None]:
train.run(
    train_dataloader=train_dataloader,
    trainer=trainer,
    workdir=workdir,
    total_train_steps=num_train_steps,
    metric_aggregation_steps=20,
    eval_dataloader=eval_dataloader,
    eval_every_steps=1000,
    num_batches_per_eval=2,
    callbacks=(
        # This callback displays the training progress in a tqdm bar
        callbacks.TqdmProgressBar(
            total_train_steps=num_train_steps,
            train_monitors=("train_loss",),
        ),
        # This callback periodically saves a model checkpoint
        callbacks.TrainStateCheckpoint(
            base_dir=workdir,
            options=checkpoint.CheckpointManagerOptions(
                save_interval_steps=ckpt_interval, max_to_keep=max_ckpt_to_keep
            ),
        ),
    ),
)

### Inference

#### Unconditional generation

The trained denoiser may be used to generate unconditional samples.

First, let's restore the model from a checkpoint.

In [None]:
# Restore train state from checkpoint. By default, the most recently saved
# checkpoint is restored. Alternatively, one can directly use
# `trainer.train_state` if continuing from the training section above.
trained_state = unconditional.TrainState.restore_from_orbax_ckpt(
    f"{workdir}/checkpoints", step=None
)
# Construct the inference function
denoise_fn = unconditional.DenoisingTrainer.inference_fn_from_state_dict(
    trained_state, use_ema=True, denoiser=denoiser_model
)

Diffusion samples are generated by plugging the trained denoising function into a stochastic differential equation (parametrized by the diffusion scheme) and solving it backwards in time.
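
For intuition, the sketch below solves such a reverse-time SDE with the Euler-Maruyama method on a toy problem where the ideal denoiser is known in closed form (scalar Gaussian data). It is an illustration rather than the library's sampler: it assumes a variance-exploding process whose time variable is the noise level itself, a geometrically decaying noise schedule, and made-up values for `sigma_max` and `data_std`. All function names are hypothetical.

```python
import numpy as np


def exponential_sigma_schedule(sigma_max, sigma_min, num_steps):
  """Noise levels decaying geometrically from sigma_max to sigma_min."""
  return sigma_max * (sigma_min / sigma_max) ** (
      np.arange(num_steps + 1) / num_steps
  )


def euler_maruyama_sample(denoise_fn, sigmas, shape, rng):
  """Integrates the reverse-time SDE from high noise down to low noise."""
  x = sigmas[0] * rng.standard_normal(shape)  # start from pure noise
  for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
    step = sigma - sigma_next  # positive, since the noise level decreases
    # Score of the noisy data distribution, expressed through the denoiser.
    score = (denoise_fn(x, sigma) - x) / sigma**2
    drift = 2.0 * sigma * score
    noise = np.sqrt(2.0 * sigma * step) * rng.standard_normal(shape)
    x = x + drift * step + noise
  return x


data_std = 0.5  # toy data: zero-mean Gaussian with this standard deviation


def gaussian_denoiser(x_noisy, sigma):
  """Ideal denoiser for zero-mean Gaussian data (posterior mean)."""
  return x_noisy * data_std**2 / (data_std**2 + sigma**2)


sigmas = exponential_sigma_schedule(sigma_max=80.0, sigma_min=1e-3, num_steps=256)
toy_samples = euler_maruyama_sample(
    gaussian_denoiser, sigmas, (5000,), np.random.default_rng(0)
)
print(toy_samples.std())  # should be close to data_std
```

Because the denoiser is exact in this toy setting, any remaining mismatch between the sample statistics and `data_std` comes from time discretization and the finite number of samples.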

In [None]:
from swirl_dynamics.lib.diffusion import samplers
from swirl_dynamics.lib.solvers import sde

In [None]:
sampler = samplers.SdeSampler(
    input_shape=(64, 64, 1),
    integrator=sde.EulerMaruyama(),
    scheme=diffusion_scheme,
    denoise_fn=denoise_fn,
)

In [None]:
# Optional: JIT compile the generate function so that it runs faster if
# repeatedly called.
generate = jax.jit(sampler.generate, static_argnums=(2,))

In [None]:
# Time steps for the SDE solver
tspan = samplers.exponential_noise_decay(
    scheme=diffusion_scheme, num_steps=256, end_sigma=1e-3
)
samples, aux = generate(
    rng=jax.random.PRNGKey(88), tspan=tspan, num_samples=4
)

In the output, `samples` holds the generated samples and `aux` holds auxiliary output from the generation process. The latter contains the full trajectory of the SDE, which may be inspected to better understand the generation behavior.

In [None]:
print(samples.shape)
print(aux["trajectories"].shape)

Visualize the generated samples

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Plot generated samples
vmin, vmax = -3, 3

fig, ax = plt.subplots(1, 4, figsize=(10, 2))
for i in range(4):
  im = ax[i].imshow(samples[i, :, :, 0], vmin=vmin, vmax=vmax)
  fig.colorbar(im, ax=ax[i])

plt.tight_layout()
plt.show()

In [None]:
# Plot SDE trajectory
steps = 8
sample_id = 0
vmin, vmax = -3, 3

fig, ax = plt.subplots(1, steps, figsize=(steps * 2.5, 2))
for i in range(steps):
  step_idx = i * (aux["trajectories"].shape[0] // steps)
  im = ax[i].imshow(
      aux["trajectories"][step_idx, sample_id, :, :, 0], vmin=vmin, vmax=vmax
  )
  ax[i].set_title(f"diffusion time {tspan[step_idx]: .3f}")
  fig.colorbar(im, ax=ax[i])

plt.tight_layout()
plt.show()

#### A-posteriori guided generation

We may post-process a trained denoising function to perform "guided" generation. Below we provide an example for a super-resolution task: generating high-resolution images given a low-resolution one.

To achieve this, we provide the low-resolution image as the guide input and post-process the denoiser to favor generating samples which, when downsampled, give values close to this guide input.

In [None]:
from swirl_dynamics.lib.diffusion import guidance

In [None]:
guidance_fn = guidance.InfillFromSlices(
    # This specifies the location of the guide input using Python slices.
    # Here it implies that the guide input corresponds to pixels at 0, 8, ...
    slices=(slice(None), slice(None, None, 8), slice(None, None, 8)),

    # This parameter controls how "hard" the denoiser pushes for the
    # conditioning to be satisfied. At higher values, the conditioning is
    # better satisfied at the cost of sample diversity.
    guide_strength=0.1,
)

In [None]:
guided_sampler = samplers.SdeSampler(
    input_shape=(64, 64, 1),
    integrator=sde.EulerMaruyama(),
    scheme=diffusion_scheme,
    denoise_fn=denoise_fn,
    guidance_fn=guidance_fn,
)

guided_generate = jax.jit(guided_sampler.generate, static_argnums=(2,))

In [None]:
guided_samples, _ = guided_generate(
    rng=jax.random.PRNGKey(66),
    tspan=samplers.exponential_noise_decay(
        scheme=diffusion_scheme, num_steps=128, end_sigma=1e-3
    ),
    num_samples=4,
    # The shape of the guidance input must be compatible with
    # `sample[guidance_fn.slices]`
    guidance_input=jnp.ones((1, 8, 8, 1)),
)
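
The shape requirement in the comment above can be checked with plain NumPy indexing. The sketch below uses random arrays as stand-ins for real images, treats "downsampling" as strided subsampling with the same slices as the guidance function, and introduces a hypothetical `conditioning_error` helper to measure how closely a batch of samples matches the guide input.

```python
import numpy as np

# Same slices as passed to the guidance function: every 8th pixel per axis.
slices = (slice(None), slice(None, None, 8), slice(None, None, 8))

# Stand-in for a batch of high-resolution images (or generated samples).
high_res = np.random.default_rng(0).normal(size=(4, 64, 64, 1))
low_res = high_res[slices]
print(low_res.shape)  # (4, 8, 8, 1)


def conditioning_error(samples, guidance_input, slices):
  """Mean squared mismatch between subsampled samples and the guide input."""
  return float(np.mean((samples[slices] - guidance_input) ** 2))


# Zero by construction: low_res was extracted from high_res itself.
exact_error = conditioning_error(high_res, low_res, slices)
# A guide input with batch size 1 broadcasts against all 4 samples.
broadcast_error = conditioning_error(high_res, np.ones((1, 8, 8, 1)), slices)
```

Applying the same helper to `guided_samples` and the guide input gives a quick check of how the choice of `guide_strength` affects the fit.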

Visualize guided samples

In [None]:
import matplotlib.pyplot as plt

In [None]:
vmin, vmax = -3, 3

fig, ax = plt.subplots(1, 4, figsize=(10, 2))
for i in range(4):
  im = ax[i].imshow(guided_samples[i, :, :, 0], vmin=vmin, vmax=vmax)
  fig.colorbar(im, ax=ax[i])

plt.tight_layout()
plt.show()