# Training MNIST using discrete diffusion 🚀

In this colab we show how to train a **discrete** diffusion model on the MNIST dataset. It can run on any Colab backend.

In [None]:
################################################################################
# Common modules
################################################################################

import dataclasses
import functools
from etils import ecolab
import flax.linen as nn
import grain.python as pygrain
import jax
from jax.experimental import checkify
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import tensorflow_datasets as tfds
import tqdm

################################################################################
# Hackable diffusion modules
################################################################################

cell_autoreload = True  # @param{type: "boolean"}

with ecolab.adhoc(
    reload=["hackable_diffusion"],
    invalidate=False,
    cell_autoreload=cell_autoreload,
):
  from hackable_diffusion import hd

In [None]:
# Short aliases for the hackable_diffusion submodules used below, grouped by
# the part of the library they come from.

# Top-level building blocks.
diffusion_network = hd.diffusion_network
time_sampling = hd.time_sampling

# Corruption (noise) process.
schedules = hd.corruption.schedules
discrete = hd.corruption.discrete

# Architecture.
arch_typing = hd.architecture.arch_typing
unet = hd.architecture.unet
discrete_backbone = hd.architecture.discrete
conditioning_encoder = hd.architecture.conditioning_encoder

# Loss.
discrete_loss = hd.loss.discrete

# Inference.
diffusion_inference = hd.inference.diffusion_inference
wrappers = hd.inference.wrappers

# Sampling.
time_scheduling = hd.sampling.time_scheduling
discrete_step_sampler = hd.sampling.discrete_step_sampler
sampling = hd.sampling.sampling

# Prepare MNIST data

Create py-grain data structures for convenient batching and loading.

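MNIST images are $28 \times 28 \times 1$. We store each image as int32 tokens with one extra trailing token axis, i.e. with shape $28 \times 28 \times 1 \times 1$.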

In [None]:
@dataclasses.dataclass(frozen=True)
class PreprocessExample(pygrain.MapTransform):
  """Turns a raw MNIST record into the model's input format.

  The image becomes an int32 array of shape (28, 28, 1, 1): the (28, 28, 1)
  image plus a trailing singleton axis that holds the token.
  """

  def map(self, x: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Returns `{'data': int32 (28, 28, 1, 1) tokens, 'label': int32 label}`."""
    pixels = np.reshape(x['image'].astype(np.int32), (28, 28, 1))
    # Trailing token axis expected by the discrete backbone.
    tokens = pixels[..., np.newaxis]
    return {'data': tokens, 'label': np.int32(x['label'])}


def mnist_dataset(
    batch_size: int,
    train: bool,
    split: str = 'all',
    seed: int = 0,
) -> pygrain.DataLoader:
  """Builds a py-grain loader over the TFDS MNIST data source.

  Args:
    batch_size: Number of examples per batch; an incomplete final batch is
      dropped.
    train: If true, examples are shuffled; otherwise they are read in source
      order.
    split: TFDS split to read. Defaults to 'all', which means that by default
      the `train=False` loader reads the same examples as the training loader.
      Pass e.g. 'test' to get held-out data.
    seed: Seed passed to `pygrain.load` (used for shuffling).

  Returns:
    A `pygrain.DataLoader` yielding batched dicts with the 'data' and 'label'
    entries produced by `PreprocessExample`.
  """
  return pygrain.load(
      source=tfds.data_source(name='mnist', split=split),
      shuffle=bool(train),
      shard_options=pygrain.ShardByJaxProcess(drop_remainder=True),
      transformations=[PreprocessExample()],
      batch_size=batch_size,
      drop_remainder=True,
      seed=seed,
  )

In [None]:
# Grab one unshuffled batch of 64 examples and show it on an 8x8 grid, each
# panel titled with its class label.
mnist_plot_batch = next(iter(mnist_dataset(64, train=False)))
mnist_plot_images = mnist_plot_batch['data']
mnist_plot_labels = mnist_plot_batch['label']

fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for img, label, ax in zip(mnist_plot_images, mnist_plot_labels, axes.flat):
  # Each image is (28, 28, 1, 1); drop the trailing token axis for imshow.
  ax.imshow(img[:, :, :, 0])
  ax.set_title(str(int(label)), fontsize=8)
  ax.axis('off')

fig.suptitle('MNIST examples')
plt.tight_layout()
plt.show()

In [None]:
# Widen printed lines so each 28-pixel row of the first image fits on one line.
np.set_printoptions(linewidth=160)
mnist_plot_images[0][..., 0, 0]

# Define all diffusion model modules

## Noise process

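We use a cosine discrete schedule together with a masking process over the 256 pixel-intensity values (an extra mask token is added on top of them).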

In [None]:
# Cosine noise schedule driving a masking corruption over the 256 possible
# pixel values.
schedule = schedules.CosineDiscreteSchedule()
process = discrete.CategoricalProcess.masking_process(
    num_categories=256,
    schedule=schedule,
)

In [None]:
# Sanity check: the masking process should put zero invariant mass on every
# pixel category and all of it on the extra mask token appended at the end.
# An assertion fails loudly, whereas a bare `==` expression only displays a
# value that is easy to overlook.
invariant_probs_masking = (0.0,) * process.num_categories + (1.0,)
np.testing.assert_array_equal(
    np.asarray(process.invariant_probs), np.asarray(invariant_probs_masking)
)

In [None]:
process.is_masking  # presumably True for a process built with `masking_process`

Visualize the noise process at several diffusion times.

In [None]:
# Corrupt the plotting batch at evenly spaced diffusion times and show the
# first image at each time. The same key is passed for every time, so the
# random draws are shared across panels and only `time` varies.
num_noises = 7
fig, axes = plt.subplots(
    ncols=num_noises, figsize=(num_noises * 4, 4), sharex=True, sharey=True
)

corrupt_rng = jax.random.PRNGKey(10)
x0_plot = jnp.array(mnist_plot_images)  # Loop-invariant; convert once.
for ax, time in zip(axes, jnp.linspace(1e-3, 1.0 - 1e-3, num=num_noises)):
  xt, _ = process.corrupt(
      key=corrupt_rng,
      x0=x0_plot,
      time=jnp.ones((1,)) * time,
  )
  ax.imshow(xt[0, :, :, :, 0])
  ax.axis('off')
  # Format the 0-d array as a short float instead of its full repr.
  ax.set_title(f'Time = {float(time):.3f}')

plt.tight_layout()
plt.show()

## Define diffusion network backbone

First, we define the diffusion backbone -- an architecture that takes `x` and `conditioning_embeddings`, as well as `is_training`, and returns the same type as `x`.

Here, we use a small version of `Unet`.

In [None]:
# Small U-Net with three levels (channel multipliers 1, 2, 2); dropout is only
# used in the last level and in the bottleneck, and no attention is used.
base_backbone = unet.Unet(
    base_channels=32,
    channels_multiplier=(1, 2, 2),
    num_residual_blocks=(2, 2, 2),
    downsample_method=arch_typing.DownsampleType.AVG_POOL,
    upsample_method=arch_typing.UpsampleType.NEAREST,
    dropout_rate=(0.0, 0.0, 0.2),
    bottleneck_dropout_rate=0.2,
    self_attention_bool=(False, False, False),
    cross_attention_bool=(False, False, False),
    attention_normalize_qk=False,
    attention_use_rope=False,
    attention_rope_position_type=arch_typing.RoPEPositionType.SQUARE,
    attention_num_heads=8,
    attention_head_dim=-1,
    normalization_type=arch_typing.NormalizationType.RMS_NORM,
    normalization_num_groups=None,
    zero_init_output=False,
    activation='gelu',
    skip_connection_method=arch_typing.SkipConnectionMethod.NORMALIZED_ADD,
)

# Wraps the U-Net for integer-token data. `token_embedder` turns tokens into
# features for `base_backbone`; `token_projector` maps its output back to
# `num_categories` values per token.
backbone = discrete_backbone.ConditionalDiscreteBackbone(
    base_backbone=base_backbone,
    token_embedder=discrete_backbone.TokenEmbedder(
        # Presumably includes the extra mask token on top of the pixel values,
        # since noisy inputs `xt` can contain it.
        process_num_categories=process.process_num_categories,
        embedding_dim=32,
        adapt_to_image_like_data=True,
    ),
    token_projector=discrete_backbone.DenseTokenProjector(
        embedding_dim=32,
        # Read from the process instead of hard-coding 256, so the output
        # vocabulary cannot drift from the corruption process.
        num_categories=process.num_categories,
        adapt_to_image_like_data=True,
    ),
)

## Define conditioning logic

Now, we define the conditioning embedders as well as the time embedder. The conditioning encoder processes each conditioning signal (for MNIST, each batch comes with its class label, stored under the key `label`).

The conditioning embedders form a dictionary with key `label`; here the value is an `nn.Module`, namely a simple `LabelEmbedder`. If you want to train a purely unconditional model, set `conditioning_embedders = {}`.



In [None]:
################################################################################
# Conditional diffusion.
################################################################################

# One embedder per conditioning key; set to {} for an unconditional model.
conditioning_embedders = {
    'label': conditioning_encoder.LabelEmbedder(
        num_classes=10,
        num_features=256,
        conditioning_key='label',
    )
}

# The diffusion time and every entry of `conditioning_embedders` are injected
# via adaptive normalization. Deriving the rules from the embedder keys keeps
# them in sync with `conditioning_embedders` (e.g. when it is set to {}).
conditioning_rules = {'time': arch_typing.ConditioningMechanism.ADAPTIVE_NORM}
conditioning_rules.update({
    key: arch_typing.ConditioningMechanism.ADAPTIVE_NORM
    for key in conditioning_embedders
})

# Time and label embedders both produce 256 features; presumably they must
# match because the embeddings are merged by summation.
encoder = conditioning_encoder.ConditioningEncoder(
    time_embedder=conditioning_encoder.SinusoidalTimeEmbedder(
        activation='gelu', embedding_dim=256, num_features=256
    ),
    conditioning_embedders=conditioning_embedders,
    embedding_merging_method=arch_typing.EmbeddingMergeMethod.SUM,
    conditioning_rules=conditioning_rules,
)

## Putting all together into diffusion network

In [None]:
# Tie the backbone and the conditioning encoder together. The network's
# outputs are treated as logits, and its data dtype is int32 tokens.
network = diffusion_network.DiffusionNetwork(
    conditioning_encoder=encoder,
    backbone_network=backbone,
    data_dtype=jnp.int32,
    prediction_type='logits',
)

Model visualization

In [None]:
summary_depth = 2  # @param {type: "integer"}

# Single-example dummy inputs matching the training layout: one time value,
# a (1, 28, 28, 1, 1) int32 token image and an int32 label.
dummy_time = jnp.ones((1,))
dummy_xt = jnp.ones((1, 28, 28, 1, 1), dtype=jnp.int32)
dummy_conditioning = {"label": jnp.ones((1,), dtype=jnp.int32)}

# Module summary table, expanded `summary_depth` levels deep.
tabulate_fn = nn.tabulate(
    network,
    jax.random.PRNGKey(42),
    depth=summary_depth,
    console_kwargs={"force_jupyter": True, "soft_wrap": True},
)
print(tabulate_fn(dummy_time, dummy_xt, dummy_conditioning, is_training=False))

## Define time sampler, optimizer and loss function

The time is sampled uniformly in the interval $[\epsilon,1 - \epsilon]$.

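The loss is a diffusion cross-entropy loss (`DiffusionCrossEntropyLoss`) between the network's predicted logits and the targets returned by the corruption process; the loss is also given the noise schedule.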

In [None]:
# Cross-entropy diffusion loss, configured with the same noise schedule as the
# corruption process.
loss_fn = discrete_loss.DiffusionCrossEntropyLoss(schedule=schedule)

# Diffusion times are drawn uniformly from [1e-3, 1 - 1e-3].
time_sampler = time_sampling.UniformTimeSampler(safety_epsilon=1e-3)

# Adam with global-norm clipping and a constant 5e-4 step size, written as an
# explicit chain of gradient transformations (applied top to bottom).
optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm=1.0),
    optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
    optax.scale_by_schedule(optax.constant_schedule(value=5e-4)),
    # optax updates are added to the params, so flip the sign to descend.
    optax.scale(-1.0),
)

## Define the parameters loss function and gradient function

Here we write the loss and its gradient as functions of the network parameters, which is what the optimizer needs to train the network.

In [None]:
@jax.jit
def params_loss_fn(params, x0, conditioning, rng):
  """Diffusion training loss for one batch, as a function of `params`.

  Samples diffusion times, corrupts `x0` with `process`, runs the network in
  training mode and averages the cross-entropy loss.

  Args:
    params: Flax parameter tree of `network`.
    x0: Batch of clean int32 tokens (in this notebook, (B, 28, 28, 1, 1)).
    conditioning: Dict of conditioning arrays, e.g. {'label': labels}.
    rng: PRNG key for this step. It is split into independent keys for time
      sampling, corruption and dropout.

  Returns:
    `(loss, metrics)`, where `loss` is the scalar mean loss and `metrics` is
    `{'loss': loss}`; the aux output lets `jax.grad(..., has_aux=True)` return
    the metrics alongside the gradients.
  """
  # Three independent keys. The original reused `rng` for dropout after it had
  # already been split, which JAX's PRNG guidance says never to do.
  time_rng, corrupt_rng, dropout_rng = jax.random.split(rng, 3)
  time = time_sampler(key=time_rng, data_spec=x0)
  xt, targets = process.corrupt(key=corrupt_rng, x0=x0, time=time)
  output = network.apply(
      {'params': params},
      time=time,
      xt=xt,
      conditioning=conditioning,
      is_training=True,
      rngs={'dropout': dropout_rng},
  )
  loss = jnp.mean(loss_fn(preds=output, targets=targets, time=time))
  return loss, {'loss': loss}


# Gradients w.r.t. `params`, plus the metrics dict returned as aux data.
grad_fn = jax.jit(jax.grad(params_loss_fn, has_aux=True))

We wrap the whole update step (gradient plus optimizer update) into a single jitted `update_fn`, which makes the updates much faster.

In [None]:
@jax.jit
def update_fn(params, opt_state, x0, conditioning, rng):
  """Runs one optimizer step.

  Returns:
    `(new_params, new_opt_state, metrics)`, where `metrics` is the aux dict
    produced by the loss.
  """
  grads, metrics = grad_fn(params, x0, conditioning, rng)
  param_updates, new_opt_state = optimizer.update(grads, opt_state)
  new_params = optax.apply_updates(params, param_updates)
  return new_params, new_opt_state, metrics

# Train the model

In [None]:
nepochs = 15
batch_size = 256
num_train_examples = 60_000  # Size of the MNIST 'train' split.
# Optimizer steps per "epoch". NOTE(review): `mnist_dataset` reads split='all'
# by default, so one "epoch" here is a fixed step count, not necessarily one
# full pass over the loaded data.
epoch_size = num_train_examples // batch_size

rng = jax.random.PRNGKey(0)
# Use a dedicated key for initialization, so the key carried into the training
# loop was never consumed directly.
rng, init_rng = jax.random.split(rng)

params = network.initialize_variables(
    input_shape=(1, 28, 28, 1, 1),
    conditioning_shape={'label': (1,)},
    key=init_rng,
    is_training=True,
)['params']

In [None]:
opt_state = optimizer.init(params)

train_iter = iter(mnist_dataset(batch_size, train=True))

losses = []  # Mean training loss of each epoch.
for epoch in tqdm.tqdm(range(1, nepochs + 1)):
  epoch_loss = 0.0
  steps = 0
  for _ in range(epoch_size):
    # Read a batch of data.
    batch = next(train_iter)
    x0 = batch['data']
    conditioning = {'label': batch['label']}
    # Split off a fresh per-step key. `rng` itself is only ever split, never
    # passed to the model, so no key is used twice.
    rng, step_rng = jax.random.split(rng)
    params, opt_state, metrics = update_fn(
        params, opt_state, x0, conditioning, step_rng
    )
    epoch_loss += metrics['loss']
    steps += 1
  # One host sync per epoch; guard against an empty epoch.
  mean_epoch_loss = float(epoch_loss) / max(steps, 1)
  print(f'Epoch = {epoch}, Mean epoch loss = {mean_epoch_loss:.4f}')
  losses.append(mean_epoch_loss)

In [None]:
# Training loss recorded once per epoch by the training loop.
fig, ax = plt.subplots()
ax.plot(range(1, len(losses) + 1), losses, marker='o')
ax.set(xlabel='Epoch', ylabel='Training loss', title='Training loss per epoch')
plt.show()

# It's inference time

Below, we define the inference function: a pure JAX function that takes `t`, `xt` and `c` and returns the network's prediction for `x0` (here, logits over the token categories, since the network was built with `prediction_type='logits'`).

In [None]:
# Wrap the trained Flax network and its parameters as a plain inference
# function, then add the guidance layer on top (no guidance options are set).
base_inference_fn = wrappers.FlaxLinenInferenceFn(network=network, params=params)
inference_fn = diffusion_inference.GuidedDiffusionInferenceFn(
    base_inference_fn=base_inference_fn,
)

## Sampler -- time_schedule, stepper and sampler itself

In [None]:
num_sampling_steps = 28 * 28  # As many denoising steps as pixel positions.
time_schedule = time_scheduling.UniformTimeSchedule()
stepper = discrete_step_sampler.UnMaskingStep(corruption_process=process)

sampler = sampling.DiffusionSampler(
    time_schedule=time_schedule, stepper=stepper, num_steps=num_sampling_steps
)
sampler = functools.partial(sampler, inference_fn=inference_fn)
# `checkify` is imported explicitly rather than reached as
# `jax.experimental.checkify`, which only works if some other module happened
# to import it first. The checkified sampler returns `(error, result)`; call
# `error.throw()` to surface failed runtime checks.
sampler = jax.jit(checkify.checkify(sampler))

## Sampling the data

* First, we sample with labels taken from a batch of real data, which approximates sampling from $p(x_0)$ (the labels follow the data distribution).

* Second, we sample with a single fixed label $c$, which samples from $p(x_0 \mid c)$.

In [None]:
num_samples = 64
data_spec = jnp.ones((num_samples, 28, 28, 1, 1), dtype=jnp.int32)
specific_label = 5

# An unshuffled evaluation batch: its labels drive the first set of samples and
# its images are plotted below for comparison.
eval_iter = iter(mnist_dataset(num_samples, train=False))
eval_data = next(eval_iter)


def sample_with_conditioning(key, conditioning):
  """Runs the checkified sampler to draw `num_samples` samples.

  Args:
    key: PRNG key. It is split so the initial noise and the sampler get
      independent keys, instead of both consuming `key`.
    conditioning: Dict of conditioning arrays, e.g. {'label': labels}.

  Returns:
    The first element of the sampler's output tuple; its `.xt` field holds the
    generated tokens plotted below.

  Raises:
    Whatever `err.throw()` raises if a checkify check inside the sampler failed.
  """
  noise_key, sampler_key = jax.random.split(key)
  initial_noise = process.sample_from_invariant(
      key=noise_key, data_spec=data_spec
  )
  err, (out, _) = sampler(
      rng=sampler_key, initial_noise=initial_noise, conditioning=conditioning
  )
  # `sampler` is wrapped in checkify, which returns errors instead of raising
  # them; discarding `err` would silently hide failed checks.
  err.throw()
  return out


################################################################################
# Sample conditionally using dataset labels (approximates p(x0))
################################################################################

out_cond = sample_with_conditioning(
    jax.random.PRNGKey(0), {'label': eval_data['label']}
)

################################################################################
# Sample from a given label (samples p(x0 | c))
################################################################################

out_label = sample_with_conditioning(
    jax.random.PRNGKey(1),
    {'label': jnp.full((num_samples,), specific_label, dtype=jnp.int32)},
)

Visualize true dataset

In [None]:
# Real images from the evaluation batch, each titled with its class label.
# These labels are the conditioning used for `out_cond`.
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for img, label, ax in zip(eval_data['data'], eval_data['label'], axes.flat):
  ax.imshow(img[:, :, :, 0])
  ax.set_title(str(int(label)), fontsize=8)
  ax.axis('off')

fig.suptitle('Evaluation batch (true data)')
plt.tight_layout()
plt.show()

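Inspect the raw pixel values of the first true image, then visualize samples from $p(x_0)$ (generated with labels taken from the evaluation batch).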

In [None]:
# Raw int32 pixel values of the first evaluation image as a 28x28 grid.
np.set_printoptions(linewidth=160)
eval_data['data'][0][..., 0, 0]

In [None]:
# Samples conditioned on the evaluation batch's labels; each panel is titled
# with the label it was conditioned on.
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for img, label, ax in zip(out_cond.xt, eval_data['label'], axes.flat):
  ax.imshow(img[:, :, :, 0])
  ax.set_title(str(int(label)), fontsize=8)
  ax.axis('off')

fig.suptitle('Samples conditioned on dataset labels')
plt.tight_layout()
plt.show()

In [None]:
# Raw token values of the first conditionally generated sample as a 28x28 grid.
np.set_printoptions(linewidth=160)
out_cond.xt[0, :, :, 0, 0]

Visualize samples from $p(x_0 | c)$

In [None]:
# Samples that were all conditioned on the same label, `specific_label`.
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for img, ax in zip(out_label.xt, axes.flat):
  ax.imshow(img[:, :, :, 0])
  ax.axis('off')

fig.suptitle(f'Samples conditioned on label {specific_label}')
plt.tight_layout()
plt.show()

In [None]:
# Raw token values of the first fixed-label sample as a 28x28 grid.
np.set_printoptions(linewidth=160)
out_label.xt[0, :, :, 0, 0]