# Training Multimodal MNIST 🚀

In this colab we showcase how to train a multimodal diffusion model on the
MNIST dataset, jointly modelling a continuous and a discrete view of each
image. This colab can run on any colab backend.

In [None]:
################################################################################
# Common modules
################################################################################

import dataclasses
import functools
from typing import Protocol
from etils import ecolab
from flax import linen as nn
import grain.python as pygrain
import jax
import jax
import jax.numpy as jnp
from jaxtyping import PyTree  # pylint: disable=g-multiple-import,g-importing-member
import matplotlib.pyplot as plt
import numpy as np
import optax
import tensorflow_datasets as tfds
import tqdm


################################################################################
# Hackable diffusion modules
################################################################################

cell_autoreload = True  # @param{type: "boolean"}

with ecolab.adhoc(
    reload=["hackable_diffusion"],
    invalidate=False,
    cell_autoreload=cell_autoreload,
):
  from hackable_diffusion import hd

In [None]:
# Short aliases for the `hd` submodules used throughout this colab, grouped by
# the subpackage they come from.

# Top-level modules.
diffusion_network = hd.diffusion_network
hd_typing = hd.hd_typing
time_sampling = hd.time_sampling
utils = hd.utils

# hd.corruption
base_process = hd.corruption.base
discrete_process = hd.corruption.discrete
gaussian_process = hd.corruption.gaussian
schedules = hd.corruption.schedules

# hd.architecture
arch_typing = hd.architecture.arch_typing
conditioning_encoder = hd.architecture.conditioning_encoder
discrete_backbone = hd.architecture.discrete
unet = hd.architecture.unet

# hd.loss
base_loss = hd.loss.base
discrete_loss = hd.loss.discrete
gaussian_loss = hd.loss.gaussian

# hd.inference
diffusion_inference = hd.inference.diffusion_inference
wrappers = hd.inference.wrappers

# hd.sampling
base_sampler = hd.sampling.base
discrete_step_sampler = hd.sampling.discrete_step_sampler
gaussian_step_sampler = hd.sampling.gaussian_step_sampler
sampling = hd.sampling.sampling
time_scheduling = hd.sampling.time_scheduling

# Prepare MNIST data

Create py-grain data structures for convenient batching and loading.

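Each example is turned into two modalities: a continuous image of shape
$28 \times 28 \times 3$ (the grayscale image tiled over three channels and
scaled to lie between $-1.0$ and $1.0$) and a discrete image of shape
$28 \times 28 \times 1 \times 1$ holding the integer pixel values.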

In [None]:
@dataclasses.dataclass(frozen=True)
class PreprocessExample(pygrain.MapTransform):
  """Turns one raw MNIST record into the two modalities used in this colab."""

  def map(self, x: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Builds the continuous and the discrete view of a single image.

    Args:
      x: Raw example; its 'image' and 'label' entries are read.

    Returns:
      A dict with two entries. 'data' holds 'data_continuous', a float32 array
      of shape (28, 28, 3) computed as `pixel / 127.5 - 1.0` and tiled over
      three channels, and 'data_discrete', an int32 array of shape
      (28, 28, 1, 1) with the pixel values cast to integers. 'label' is the
      label converted to `np.int32`.
    """
    pixels = x['image']

    # Continuous modality: rescale, then repeat the single channel three times.
    grey = np.reshape(pixels.astype(np.float32) / 127.5 - 1.0, (28, 28, 1))
    continuous = np.tile(grey, (1, 1, 3))

    # Discrete modality: integer pixel values with a trailing axis for tokens.
    tokens = np.reshape(pixels.astype(np.int32), (28, 28, 1))
    tokens = tokens[..., np.newaxis]

    modalities = {
        'data_continuous': continuous,
        'data_discrete': tokens,
    }
    return {'data': modalities, 'label': np.int32(x['label'])}


def mnist_dataset(batch_size: int, train: bool) -> pygrain.DataLoader:
  """Builds a grain loader yielding batches of preprocessed MNIST examples.

  Args:
    batch_size: Number of examples per batch; incomplete batches are dropped.
    train: Whether the examples are shuffled. It only controls shuffling: the
      same data source is read in both modes.

  Returns:
    A `pygrain.DataLoader` whose elements are the batched outputs of
    `PreprocessExample`.
  """
  return pygrain.load(
      # NOTE(review): split='all' is read regardless of `train`, so training
      # and evaluation draw from the same examples -- confirm this is intended.
      source=tfds.data_source(name='mnist', split='all'),
      shuffle=bool(train),
      shard_options=pygrain.ShardByJaxProcess(drop_remainder=True),
      transformations=[PreprocessExample()],
      batch_size=batch_size,
      drop_remainder=True,
      seed=0,  # Fixed seed so that the example order is reproducible.
  )

We start by plotting both the discrete and continuous images.

In [None]:
# Fetch one un-shuffled batch of 64 examples and show its continuous modality
# on an 8x8 grid.
batch = next(iter(mnist_dataset(64, train=False)))
mnist_plot_images_continuous = batch['data']['data_continuous']

fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for ax, img in zip(axes.ravel(), mnist_plot_images_continuous[:64]):
  ax.imshow(img)
  ax.axis('off')

fig.tight_layout()
plt.show()

In [None]:
# Same 8x8 grid for the discrete modality of the batch; the trailing token
# axis is dropped before plotting.
mnist_plot_images_discrete = batch['data']['data_discrete']
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for ax, img in zip(axes.ravel(), mnist_plot_images_discrete[:64]):
  ax.imshow(img[..., 0])
  ax.axis('off')

fig.tight_layout()
plt.show()

In [None]:
# Allow wide lines so that the rows of the printed array are less likely to be
# wrapped.
np.set_printoptions(linewidth=160)
# Display the first discrete image as a 2-D array (both trailing singleton
# axes are indexed away).
mnist_plot_images_discrete[0][:, :, 0, 0]

# Define all diffusion model modules

## Noise process

We use the Rectified Flow noise schedule $x_{t} = (1-t) x_0 + t \epsilon$, with
$\epsilon \sim N(0, I)$, for the continuous data and a masking process with a
cosine schedule for the discrete data.

In [None]:
# Continuous modality: a Gaussian corruption process driven by the
# rectified-flow schedule.
schedule_continuous = schedules.RFSchedule()
process_continuous = gaussian_process.GaussianProcess(schedule=schedule_continuous)

In [None]:
# Discrete modality: a masking process with a cosine schedule over 256
# categories (presumably one per 8-bit pixel intensity -- TODO confirm).
schedule_discrete = schedules.CosineDiscreteSchedule()
process_discrete = discrete_process.CategoricalProcess.masking_process(
    schedule=schedule_discrete, num_categories=256
)

In [None]:
# Combine both per-modality processes. The keys mirror the entries of
# `batch['data']`, which is what gets passed to `process.corrupt` as `x0`.
process = base_process.NestedProcess(
    processes={
        'data_continuous': process_continuous,
        'data_discrete': process_discrete,
    }
)

Visualize noise process

In [None]:
# Corrupt the same clean batch at evenly spaced times and show the continuous
# modality of its first example at each time.
num_noises = 7
fig, axes = plt.subplots(
    ncols=num_noises, figsize=(num_noises * 4, 4), sharex=True, sharey=True
)

x0 = jax.tree.map(jnp.array, batch['data'])

# The same key is passed for every time (presumably so that the panels differ
# only through the time value -- TODO confirm against `process.corrupt`).
corrupt_rng = jax.random.PRNGKey(10)
for ax, time in zip(axes, jnp.linspace(1e-3, 1.0 - 1e-3, num=num_noises)):
  time_dict = {
      'data_continuous': jnp.ones((1,)) * time,
      'data_discrete': jnp.ones((1,)) * time,
  }
  xt, _ = process.corrupt(key=corrupt_rng, x0=x0, time=time_dict)
  ax.imshow(xt['data_continuous'][0])
  ax.axis('off')
  # Computed outside the f-string: reusing the enclosing quote character inside
  # a replacement field is a syntax error before Python 3.12.
  shown_time = time_dict['data_continuous'][0].squeeze().item()
  ax.set_title(f'Time = {shown_time}')

In [None]:
# Corrupt the same clean batch at evenly spaced times and show the discrete
# modality of its first example at each time.
num_noises = 7
fig, axes = plt.subplots(
    ncols=num_noises, figsize=(num_noises * 4, 4), sharex=True, sharey=True
)

x0 = jax.tree.map(jnp.array, batch['data'])

# The same key is passed for every time (presumably so that the panels differ
# only through the time value -- TODO confirm against `process.corrupt`).
corrupt_rng = jax.random.PRNGKey(10)
for ax, time in zip(axes, jnp.linspace(1e-3, 1.0 - 1e-3, num=num_noises)):
  time_dict = {
      'data_continuous': jnp.ones((1,)) * time,
      'data_discrete': jnp.ones((1,)) * time,
  }
  xt, _ = process.corrupt(key=corrupt_rng, x0=x0, time=time_dict)
  # Drop the trailing token axis before plotting.
  ax.imshow(xt['data_discrete'][0, ..., 0])
  ax.axis('off')
  # Computed outside the f-string: reusing the enclosing quote character inside
  # a replacement field is a syntax error before Python 3.12.
  shown_time = time_dict['data_discrete'][0].squeeze().item()
  ax.set_title(f'Time = {shown_time}')

## Define diffusion network backbone

First, we define the diffusion backbone -- an architecture which takes `x` and
`conditioning_embeddings`, as well as `is_training`, and returns the same type
as `x`.

Here, we use a small version of `Unet`.

In [None]:
# Small U-Net used as the shared backbone. Attention is switched off at every
# level (`self_attention_bool` and `cross_attention_bool` are all False), and
# dropout is only non-zero in the last level and in the bottleneck.
base_backbone = unet.Unet(
    base_channels=32,
    channels_multiplier=(1, 2, 2),
    num_residual_blocks=(2, 2, 2),
    downsample_method=arch_typing.DownsampleType.AVG_POOL,
    upsample_method=arch_typing.UpsampleType.NEAREST,
    dropout_rate=(0.0, 0.0, 0.2),
    bottleneck_dropout_rate=0.2,
    self_attention_bool=(False, False, False),
    cross_attention_bool=(False, False, False),
    attention_normalize_qk=False,
    attention_use_rope=False,
    attention_rope_position_type=arch_typing.RoPEPositionType.SQUARE,
    attention_num_heads=8,
    attention_head_dim=-1,
    normalization_type=arch_typing.NormalizationType.RMS_NORM,
    normalization_num_groups=None,
    zero_init_output=False,
    activation='gelu',
    skip_connection_method=arch_typing.SkipConnectionMethod.NORMALIZED_ADD,
)

# Embeds the discrete tokens before they enter the backbone.
# NOTE(review): the embedder reads `process_num_categories` while the projector
# below reads `num_categories` -- presumably the former also counts the mask
# token; confirm against `CategoricalProcess`.
token_embedder = discrete_backbone.TokenEmbedder(
    process_num_categories=process.processes[
        'data_discrete'
    ].process_num_categories,
    embedding_dim=32,
    adapt_to_image_like_data=True,
)
# `None` marks a modality that is passed through without embedding.
embedder = {'data_continuous': None, 'data_discrete': token_embedder}
# Maps the backbone output for the discrete modality back to per-category
# values (presumably logits -- TODO confirm).
token_projector = discrete_backbone.DenseTokenProjector(
    num_categories=process.processes['data_discrete'].num_categories,
    embedding_dim=32,
    adapt_to_image_like_data=True,
)
# `None` marks a modality that is passed through without projection.
projector = {'data_continuous': None, 'data_discrete': token_projector}

## Define the Multimodal conditional backbone

In [None]:
"""Backbones for multimodal data."""

################################################################################
# MARK: Type Aliases
################################################################################

# Local shorthands for the library types used in the annotations of this cell.
Dtype = jax.typing.DTypeLike
Float = hd_typing.Float
DataArray = hd_typing.DataArray
DataTree = hd_typing.DataTree
ConditionalBackbone = arch_typing.ConditionalBackbone
BaseTokenEmbedder = discrete_backbone.BaseTokenEmbedder
BaseTokenProjector = discrete_backbone.BaseTokenProjector

################################################################################
# MARK: Multimodal
################################################################################


class BaseVectorizer(Protocol):
  """Interface for packing a data tree into one array and unpacking it again.

  Implementations are presumably expected to make `unvectorize` restore the
  tree structure that `vectorize` flattened -- TODO confirm with implementers.
  """

  def vectorize(self, x: DataTree) -> DataArray:
    """Vectorizes the input data.

    Args:
      x: Tree of arrays to combine.

    Returns:
      A single array built from `x`.
    """
    ...

  def unvectorize(self, original: DataTree, x: DataArray) -> DataTree:
    """Unvectorizes the input data.

    Args:
      original: Tree serving as the template for the result.
      x: Array to turn back into a tree.

    Returns:
      A tree built from `x`.
    """
    ...


class ChannelVectorizer(BaseVectorizer):
  """Packs a tree of arrays into one array along the last (channel) axis."""

  def vectorize(self, x: DataTree) -> DataArray:
    """Concatenates all leaves of `x` along their last axis."""
    return jnp.concatenate(jax.tree_util.tree_leaves(x), axis=-1)

  def unvectorize(self, original: DataTree, x: DataArray) -> DataTree:
    """Splits `x` along its last axis into a tree structured like `original`.

    Args:
      original: Template tree; the size of each leaf's last axis determines
        how many channels of `x` are assigned to that leaf.
      x: Array whose last axis is split.

    Returns:
      A tree with the structure of `original` whose leaves are consecutive
      slices of `x` along the last axis.
    """
    template_leaves, treedef = jax.tree_util.tree_flatten(original)
    widths = [leaf.shape[-1] for leaf in template_leaves]
    # The split points are computed with numpy (not jax.numpy) so that they are
    # concrete values.
    boundaries = np.cumsum(np.array(widths[:-1]))
    pieces = jnp.split(x, boundaries, axis=-1)
    return jax.tree_util.tree_unflatten(treedef, pieces)


class ConditionalMultimodalBackbone(ConditionalBackbone):
  """Conditional backbone that runs one base backbone over several modalities.

  Every modality is optionally embedded, all modalities are packed into a
  single array, the base backbone is applied, and the result is unpacked and
  optionally projected per modality.

  Attributes:
    base_backbone: The conditional backbone applied to the packed array. Can be
      any conditional backbone such as an MLP or a UNet.
    embedder: Tree matching the input data; each entry is either a token
      embedder applied to that modality or None to pass it through unchanged.
    projector: Tree matching the input data; each entry is either a token
      projector applied to that modality's output or None to pass it through.
    vectorizer: Packs the (embedded) modalities into one array and unpacks the
      backbone output again.
  """

  base_backbone: ConditionalBackbone
  embedder: PyTree[BaseTokenEmbedder | None]
  projector: PyTree[BaseTokenProjector | None]
  vectorizer: BaseVectorizer

  @nn.compact
  def __call__(
      self,
      x: DataTree,
      conditioning_embeddings: dict[str, Float['batch ...']],
      is_training: bool,
  ) -> DataTree:

    def _is_none(node):
      # Treat None as a leaf so that pass-through entries reach the helpers.
      return node is None

    def _embed(module, value):
      if module is None:
        return value
      return module(value, is_training=is_training)

    def _project(module, value):
      if module is None:
        return value
      return module(value, is_training=is_training)

    # Step 1: embed every modality that has an embedder.
    embedded = jax.tree.map(
        _embed, self.embedder.unfreeze(), x, is_leaf=_is_none
    )

    # Step 2: pack all modalities into one array and run the base backbone.
    packed = self.vectorizer.vectorize(x=embedded)
    backbone_out = self.base_backbone(
        x=packed,
        conditioning_embeddings=conditioning_embeddings,
        is_training=is_training,
    )

    # Step 3: unpack, using the embedded tree as the template for the split.
    unpacked = self.vectorizer.unvectorize(original=embedded, x=backbone_out)

    # Step 4: project every modality that has a projector.
    projected = jax.tree.map(
        _project, self.projector.unfreeze(), unpacked, is_leaf=_is_none
    )

    return utils.optional_bf16_to_fp32(projected)

In [None]:
# Modalities are packed along the channel axis before entering the U-Net.
vectorizer = ChannelVectorizer()

# Wrap the U-Net together with the per-modality embedders and projectors.
backbone = ConditionalMultimodalBackbone(
    base_backbone=base_backbone,
    embedder=embedder,
    projector=projector,
    vectorizer=vectorizer,
)

## Define conditioning logic

Now, we define the conditioning embedders as well as the time encoder. The
conditioning encoder processes each conditioning (in the case of MNIST data,
each batch comes with its label (`label`)).

The conditioning encoder is a dictionary with key `label` (and here the value is
a `nn.Module` which is given by a simple `LabelEmbedding` module). If you want
to train a purely unconditional model, set `conditioning_embedders = {}`.

In [None]:
################################################################################
# Conditional diffusion.
################################################################################

ConditioningMechanism = arch_typing.ConditioningMechanism

# One embedder per conditioning signal; here only the digit label (10 classes,
# 256 features), read from the 'label' key of the conditioning dict.
conditioning_embedders = {
    'label': conditioning_encoder.LabelEmbedder(
        num_classes=10,
        num_features=256,
        conditioning_key='label',
    )
}

# Each modality gets its own sinusoidal time embedder; both are configured
# identically.
time_embedder_continuous = conditioning_encoder.SinusoidalTimeEmbedder(
    activation='gelu', embedding_dim=256, num_features=256
)
time_embedder_discrete = conditioning_encoder.SinusoidalTimeEmbedder(
    activation='gelu', embedding_dim=256, num_features=256
)
time_embedder = conditioning_encoder.NestedTimeEmbedder(
    time_embedders={
        'data_continuous': time_embedder_continuous,
        'data_discrete': time_embedder_discrete,
    }
)

# Embeddings are merged with the SUM method, and both 'time' and 'label' are
# routed through the ADAPTIVE_NORM conditioning mechanism.
encoder = conditioning_encoder.ConditioningEncoder(
    time_embedder=time_embedder,
    conditioning_embedders=conditioning_embedders,
    embedding_merging_method=arch_typing.EmbeddingMergeMethod.SUM,
    conditioning_rules={
        'time': ConditioningMechanism.ADAPTIVE_NORM,
        'label': ConditioningMechanism.ADAPTIVE_NORM,
    },
)

## Putting all together into diffusion network

In [None]:
# Full network: conditioning encoder plus multimodal backbone. The continuous
# modality is configured with prediction type 'x0' and the discrete one with
# 'logits'; no input or time rescaling is applied.
network = diffusion_network.MultiModalDiffusionNetwork(
    backbone_network=backbone,
    conditioning_encoder=encoder,
    prediction_type={'data_continuous': 'x0', 'data_discrete': 'logits'},
    data_dtype={'data_continuous': jnp.float32, 'data_discrete': jnp.int32},
    input_rescaler=None,
    time_rescaler=None,
)

Model visualization

In [None]:
summary_depth = 2  # @param {type: "integer"}

# Single-example dummy inputs with the same layout as the real data; they are
# only used to trace the network for the summary table.
dummy_xt = {
    'data_continuous': jnp.ones((1, 28, 28, 3)),
    'data_discrete': jnp.ones((1, 28, 28, 1, 1), dtype=jnp.int32),
}
dummy_time = {
    'data_continuous': jnp.ones((1, 1, 1, 1)),
    'data_discrete': jnp.ones((1, 1, 1, 1, 1)),
}
dummy_conditioning = {'label': jnp.ones((1,), dtype=jnp.int32)}

tabulate_fn = nn.tabulate(
    network,
    jax.random.PRNGKey(42),
    depth=summary_depth,
    console_kwargs=dict(force_jupyter=True, soft_wrap=True),
)

print(tabulate_fn(dummy_xt, dummy_time, dummy_conditioning, is_training=False))

## Define time sampler, optimizer and loss function

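The time is sampled uniformly in the interval $[\epsilon, 1 - \epsilon]$ for
each modality.

The loss combines an unweighted Gaussian loss (`NoWeightLoss`) for the
continuous modality with a diffusion cross-entropy loss for the discrete
modality.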

In [None]:
# One uniform time sampler per modality, each with a safety margin of 1e-3.
time_sampler = time_sampling.JointNestedTimeSampler(
    samplers={
        "data_continuous": time_sampling.UniformTimeSampler(
            safety_epsilon=1e-3
        ),
        "data_discrete": time_sampling.UniformTimeSampler(safety_epsilon=1e-3),
    }
)

# Adam assembled from optax primitives: clip the global gradient norm to 1.0,
# apply Adam scaling, multiply by a constant learning rate of 5e-4, and negate
# so that the updates descend the loss.
optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm=1.0),
    optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
    optax.scale_by_schedule(optax.constant_schedule(value=5e-4)),
    optax.scale(-1.0),
)

# One loss per modality, keyed like the data.
loss_fn_continuous = gaussian_loss.NoWeightLoss()
loss_fn_discrete = discrete_loss.DiffusionCrossEntropyLoss(
    schedule=schedule_discrete
)
loss_fn = base_loss.NestedDiffusionLoss(
    losses={
        "data_continuous": loss_fn_continuous,
        "data_discrete": loss_fn_discrete,
    }
)

## Define the parameters loss function and gradient function

Here we define the loss function as well as gradient function to be dependent on
NN parameters. This is needed for training the neural network.

In [None]:
@jax.jit
def params_loss_fn(params, x0, conditioning, rng):
  """Computes the training loss for one batch as a function of the parameters.

  Args:
    params: Network parameters (the 'params' collection).
    x0: Clean data tree for the batch.
    conditioning: Conditioning dict passed to the network.
    rng: PRNG key; it is split into independent keys for time sampling,
      corruption and dropout.

  Returns:
    A `(loss, metrics)` pair where `loss` is the mean over the per-modality
    loss leaves and `metrics` is `{'loss': loss}`.
  """
  # Three independent keys: reusing the parent key for dropout after splitting
  # it would correlate dropout with the time/corruption randomness.
  time_rng, corrupt_rng, dropout_rng = jax.random.split(rng, 3)
  time = time_sampler(key=time_rng, data_spec=x0)
  xt, targets = process.corrupt(key=corrupt_rng, x0=x0, time=time)
  output = network.apply(
      {'params': params},
      time=time,
      xt=xt,
      conditioning=conditioning,
      is_training=True,
      rngs={'dropout': dropout_rng},
  )
  losses = loss_fn(preds=output, targets=targets, time=time)
  # Average the loss leaves so that every modality contributes equally.
  leaves, _ = jax.tree_util.tree_flatten(losses)
  out = jnp.mean(jnp.stack(leaves))
  return out, {'loss': out}


# Gradient w.r.t. `params` (first argument); the metrics dict is returned as
# auxiliary output.
grad_fn = jax.jit(jax.grad(params_loss_fn, has_aux=True))

Wrapping the whole update into `update_fn` since it makes the updates much
faster

In [None]:
@jax.jit
def update_fn(params, opt_state, x0, conditioning, rng):
  """Performs one optimizer step.

  Args:
    params: Current network parameters.
    opt_state: Current optimizer state.
    x0: Clean data tree for the batch.
    conditioning: Conditioning dict for the batch.
    rng: PRNG key forwarded to `grad_fn`.

  Returns:
    A tuple `(params, opt_state, metrics)` with the updated parameters, the
    updated optimizer state and the metrics returned by `grad_fn`.
  """
  gradients, step_metrics = grad_fn(params, x0, conditioning, rng)
  param_updates, next_opt_state = optimizer.update(gradients, opt_state)
  next_params = optax.apply_updates(params, param_updates)
  return next_params, next_opt_state, step_metrics

## Train the model

Training the model should take less than 10 minutes.

In [None]:
nepochs = 20
batch_size = 256
# Number of updates counted as one epoch. NOTE(review): 60000 is presumably the
# size of the MNIST train split, whereas `mnist_dataset` reads split='all' --
# confirm this is intended.
epoch_size = 60000 // batch_size

In [None]:
rng = jax.random.PRNGKey(0)

import tqdm

# Use a dedicated key for initialization so that the key carried through
# training is never consumed directly.
rng, init_rng = jax.random.split(rng)
params = network.initialize_variables(
    input_shape={
        'data_continuous': (1, 28, 28, 3),
        'data_discrete': (1, 28, 28, 1, 1),
    },
    conditioning_shape={'label': (1,)},
    key=init_rng,
    is_training=True,
)['params']

opt_state = optimizer.init(params)

train_iter = iter(mnist_dataset(batch_size, train=True))

# `losses` collects the summed loss of every epoch.
losses = []
for epoch in tqdm.tqdm(range(1, nepochs + 1)):
  epoch_loss = 0
  for _ in range(epoch_size):
    # Read batch of data
    batch = next(train_iter)
    x0 = batch['data']
    conditioning = {'label': batch['label']}
    # Split off a fresh key for this step; `rng` itself is only ever split, so
    # no key is used twice.
    rng, step_rng = jax.random.split(rng)
    # Make the parameters update
    params, opt_state, metrics = update_fn(
        params, opt_state, x0, conditioning, step_rng
    )
    epoch_loss += metrics['loss']
  print(f'Epoch = {epoch}, Cumulative epoch loss = {epoch_loss}')
  losses.append(epoch_loss)

In [None]:
# One point per epoch: the loss summed over all steps of that epoch.
plt.plot(losses)

## It's inference time

Below, we define the inference function. It creates a pure jax function which
takes `t`, `xt` and `c` to return the expected value of `x0`.

In [None]:
# Bind the trained parameters to the network.
base_inference_fn = wrappers.FlaxLinenInferenceFn(
    network=network,
    params=params,
)
# Wrap the base inference function in the guided-inference interface that is
# handed to the sampler.
inference_fn = diffusion_inference.GuidedDiffusionInferenceFn(
    base_inference_fn=base_inference_fn
)

## Sampler -- time_schedule, stepper and sampler itself

In [None]:
# `jax.experimental` submodules are not guaranteed to be loaded by
# `import jax`, so import checkify explicitly.
from jax.experimental import checkify

stochasticity_level = 1.0  # Stochasticity coefficient in DDIM
num_sampling_steps = 100  # Number of denoising steps

# Both modalities follow a uniform time schedule during sampling.
time_schedule_continuous = time_scheduling.UniformTimeSchedule()
time_schedule_discrete = time_scheduling.UniformTimeSchedule()
time_schedule = time_scheduling.NestedTimeSchedule(
    time_schedules={
        'data_continuous': time_schedule_continuous,
        'data_discrete': time_schedule_discrete,
    }
)

# One step sampler per modality, each bound to the corruption process that was
# used for training.
stepper_continuous = gaussian_step_sampler.DDIMStep(
    corruption_process=process_continuous, stoch_coeff=stochasticity_level
)
stepper_discrete = discrete_step_sampler.UnMaskingStep(
    corruption_process=process_discrete
)

stepper = base_sampler.NestedSamplerStep(
    sampler_steps={
        'data_continuous': stepper_continuous,
        'data_discrete': stepper_discrete,
    }
)

sampler = sampling.DiffusionSampler(
    time_schedule=time_schedule, stepper=stepper, num_steps=num_sampling_steps
)
sampler = functools.partial(sampler, inference_fn=inference_fn)
# After checkify the sampler returns an `(error, output)` pair.
sampler = jax.jit(checkify.checkify(sampler))

## Sampling the data

*   First, we sample data using the conditioning taken from a batch of the
    dataset, which allows us to approximate $p(x_0)$.

*   Second, we sample data with a given label, which allows us to sample from
    $p(x_0 | c)$.

In [None]:
# Convert `x0` (the last training batch when the cells are run in order) to
# JAX arrays; it is passed as `data_spec` below.
x0 = jax.tree.map(lambda x: jnp.array(x), x0)

# Starting point for sampling, drawn via the process' `sample_from_invariant`.
initial_noise = process.sample_from_invariant(key=rng, data_spec=x0)

In [None]:
num_samples = 256
specific_label = 5

eval_iter = iter(mnist_dataset(num_samples, train=False))
eval_data = next(eval_iter)


def _report_checkify_error(err, what):
  """Prints the checkify error message for `what`, if there is one."""
  message = err.get()
  if message is not None:
    # Report instead of raising so that the samples can still be inspected.
    print(f'[checkify] {what}: {message}')


################################################################################
# Sample conditionally using dataset
################################################################################

key = jax.random.PRNGKey(0)

conditioning = {"label": eval_data["label"]}
# The checkified sampler returns `(error, output)`; surface the error rather
# than discarding it silently.
err_cond, (out_cond, _) = sampler(
    rng=key, initial_noise=initial_noise, conditioning=conditioning
)
_report_checkify_error(err_cond, 'sampling with dataset labels')

################################################################################
# Sample from a given label
################################################################################

key = jax.random.PRNGKey(1)
conditioning = {
    "label": jnp.full((num_samples,), specific_label, dtype=jnp.int32)
}
err_label, (out_label, _) = sampler(
    rng=key, initial_noise=initial_noise, conditioning=conditioning
)
_report_checkify_error(err_label, 'sampling with a fixed label')

Visualize samples from $p(x_0)$

In [None]:
# Continuous samples obtained with the dataset labels: first 64 on an 8x8 grid.
cur_mnist_plot_images = out_cond['data_continuous'].xt
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for ax, img in zip(axes.ravel(), cur_mnist_plot_images[:64]):
  ax.imshow(img)
  ax.axis('off')

fig.tight_layout()
plt.show()

In [None]:
# Discrete samples obtained with the dataset labels: first 64 on an 8x8 grid,
# with the trailing token axis indexed away.
cur_mnist_plot_images = out_cond['data_discrete'].xt
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for ax, img in zip(axes.ravel(), cur_mnist_plot_images[:64, ..., 0]):
  ax.imshow(img)
  ax.axis('off')

fig.tight_layout()
plt.show()

Visualize samples from $p(x_0 | c)$

In [None]:
# Continuous samples obtained with the fixed label: first 64 on an 8x8 grid.
cur_mnist_plot_images = out_label['data_continuous'].xt
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for ax, img in zip(axes.ravel(), cur_mnist_plot_images[:64]):
  ax.imshow(img)
  ax.axis('off')

fig.tight_layout()
plt.show()

In [None]:
# Discrete samples obtained with the fixed label: first 64 on an 8x8 grid,
# with the trailing token axis indexed away.
cur_mnist_plot_images = out_label['data_discrete'].xt
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for ax, img in zip(axes.ravel(), cur_mnist_plot_images[:64, ..., 0]):
  ax.imshow(img)
  ax.axis('off')

fig.tight_layout()
plt.show()