# Training in 2D 🚀

In this colab we showcase how to train a diffusion model on a two-dimensional
dataset, here a two-dimensional projection of a noisy swiss roll. This colab
can run on any colab backend.

In [None]:
################################################################################
# Common modules
################################################################################

import functools
from etils import ecolab
import flax.linen as nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
################################################################################
# Datasets
################################################################################

from sklearn.datasets import make_swiss_roll
import tqdm

################################################################################
# Hackable diffusion modules
################################################################################

# Colab form toggle; the value is forwarded to `ecolab.adhoc` below.
cell_autoreload = True  # @param{type: "boolean"}

# Import `hackable_diffusion` inside ecolab's adhoc-import context.
# NOTE(review): `reload=[...]` presumably re-imports the listed package so that
# source edits are picked up -- confirm against the etils documentation.
with ecolab.adhoc(
    reload=["hackable_diffusion"],
    invalidate=False,
    cell_autoreload=cell_autoreload,
):
  from hackable_diffusion import hd

In [None]:
# Short aliases for the `hd` submodules used in this notebook, grouped by the
# subpackage they live in.

# Top-level modules.
diffusion_network, time_sampling = hd.diffusion_network, hd.time_sampling

# hd.corruption
gaussian, schedules = hd.corruption.gaussian, hd.corruption.schedules

# hd.architecture
arch_typing = hd.architecture.arch_typing
conditioning_encoder = hd.architecture.conditioning_encoder
mlp = hd.architecture.mlp

# hd.inference
wrappers, diffusion_inference = (
    hd.inference.wrappers,
    hd.inference.diffusion_inference,
)

# hd.loss
gaussian_loss = hd.loss.gaussian

# hd.sampling
time_scheduling = hd.sampling.time_scheduling
sampling = hd.sampling.sampling
gaussian_step_sampler = hd.sampling.gaussian_step_sampler

# Define data distribution

In [None]:
# Draw a noisy swiss roll; only the point coordinates (the first returned
# value) are used.
data, _ = make_swiss_roll(noise=0.5, n_samples=10_000)
# Keep coordinates 0 and 2 only, so that the dataset is two-dimensional and
# easy to visualize.
data = data[:, [0, 2]]

# Standardize with one scalar mean and one scalar std taken over all entries.
data = (data - data.mean()) / data.std()

# First 2048 points, kept as a smaller set for quick visualizations.
small_data_subset = data[:2048]

# Scatter the first column against the second one.
plt.scatter(*data.T)

# Define all diffusion model modules

## Noise process

In [None]:
# Noise schedule handed to the Gaussian corruption process.
# NOTE(review): `RF` presumably stands for "rectified flow" -- confirm in
# `hd.corruption.schedules`.
schedule = schedules.RFSchedule()
# Corruption process; its `corrupt` method is what the rest of the notebook
# calls to produce noisy samples `xt` from clean data `x0` at a given time.
process = gaussian.GaussianProcess(schedule=schedule)

Visualize the noise process at several corruption times.

In [None]:
# One panel per corruption time, all sharing the same axes limits.
num_noises = 7
fig, axes = plt.subplots(
    ncols=num_noises, figsize=(num_noises * 4, 4), sharex=True, sharey=True
)

# The same key is used for every panel, so only the time differs between them.
corrupt_rng = jax.random.PRNGKey(10)
# Convert the clean points once instead of once per panel.
clean_points = jnp.array(small_data_subset)
# Times strictly inside (0, 1), avoiding both endpoints.
times = jnp.linspace(1e-3, 1.0 - 1e-3, num=num_noises)

# Pair each axis with its time directly instead of tracking a manual index.
for ax, time in zip(axes, times):
  xt, _ = process.corrupt(
      key=corrupt_rng,
      x0=clean_points,
      # `time` is passed with shape (1,).
      time=jnp.ones((1,)) * time,
  )
  ax.scatter(xt[:, 0], xt[:, 1])
  ax.set_title(f'Time = {time}')

## Define diffusion network -- backbone & conditioning

In [None]:
# Backbone: a conditional MLP that receives its conditioning by concatenation.
module = mlp.ConditionalMLP(
    conditioning_mechanism=arch_typing.ConditioningMechanism.CONCATENATE,
    hidden_sizes_preprocess=[32, 16],
    hidden_sizes_postprocess=[32, 16],
    activation='gelu',
    dropout_rate=0.0,
    zero_init_output=False,
    dtype=jnp.float32,
)

# No extra conditioning signals are embedded; time is the only one configured.
conditioning_embedders = {}

# Encoder that embeds the time with sinusoidal features and routes it to the
# backbone through the CONCATENATE mechanism.
encoder = conditioning_encoder.ConditioningEncoder(
    conditioning_embedders=conditioning_embedders,
    conditioning_rules={
        'time': arch_typing.ConditioningMechanism.CONCATENATE,
    },
    embedding_merging_method='concat',
    time_embedder=conditioning_encoder.SinusoidalTimeEmbedder(
        num_features=128, embedding_dim=128, activation='gelu'
    ),
)


# Full diffusion network: backbone + conditioning encoder, configured with the
# 'velocity' prediction type.
network = diffusion_network.DiffusionNetwork(
    prediction_type='velocity',
    conditioning_encoder=encoder,
    backbone_network=module,
)

In [None]:
summary_depth = 2  # @param {type: "integer"}

# Placeholder inputs used only to trace the network for the summary table:
# one time value, one two-dimensional point and no conditioning.
dummy_time = jnp.ones((1,))
dummy_xt = jnp.ones((1, 2))
dummy_conditioning = None

# Build a function that renders a summary table of `network` down to the
# requested nesting depth.
tabulate_fn = nn.tabulate(
    network,
    jax.random.PRNGKey(42),
    console_kwargs={"soft_wrap": True, "force_jupyter": True},
    depth=summary_depth,
)

print(
    tabulate_fn(dummy_time, dummy_xt, dummy_conditioning, is_training=False)
)

## Define time sampler, optimizer and loss function

In [None]:
# Sampler for the training times; it is called with a PRNG key and the data
# batch in the loss function.
# NOTE(review): `safety_epsilon` presumably keeps sampled times away from the
# interval endpoints -- confirm in `hd.time_sampling`.
time_sampler = time_sampling.UniformTimeSampler(safety_epsilon=1e-3)

# Gradient transformation pipeline, applied in the listed order.
optimizer = optax.chain(
    # Clip the gradients so that their global norm is at most 1.0.
    optax.clip_by_global_norm(max_norm=1.0),
    # Adam rescaling of the (clipped) gradients.
    optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
    # Constant step size of 1e-3.
    optax.scale_by_schedule(optax.constant_schedule(value=1e-3)),
    # Flip the sign: `optax.apply_updates` adds the updates to the parameters,
    # so descent requires negated steps.
    optax.scale(-1.0),
)

# Loss comparing network predictions with the targets returned by
# `process.corrupt`.
# NOTE(review): the name suggests no time-dependent weighting -- confirm in
# `hd.loss.gaussian`.
loss_fn = gaussian_loss.NoWeightLoss()

## Define the parameters loss function and gradient function

In [None]:
@jax.jit
def params_loss_fn(params, x0, rng):
  """Evaluates the training loss of the network on one batch.

  Args:
    params: Network parameters, passed to `network.apply` under 'params'.
    x0: Batch of clean data points.
    rng: PRNG key; split into one key for the time sampler and one for the
      corruption process.

  Returns:
    A `(loss, metrics)` pair where `loss` is the mean of the per-element loss
    and `metrics` is `{'loss': loss}`.
  """
  key_for_time, key_for_noise = jax.random.split(rng, num=2)

  # Draw the times, then corrupt the clean batch at those times.
  t = time_sampler(key=key_for_time, data_spec=x0)
  noisy_x, regression_targets = process.corrupt(
      key=key_for_noise, x0=x0, time=t
  )

  # Forward pass in training mode, without any extra conditioning.
  prediction = network.apply(
      {'params': params},
      time=t,
      xt=noisy_x,
      conditioning=None,
      is_training=True,
  )

  per_element = loss_fn(preds=prediction, targets=regression_targets, time=t)
  loss = jnp.mean(per_element)
  return loss, {'loss': loss}


# Jitted gradient of the loss with respect to `params` (the first argument).
# With `has_aux=True` it returns `(grads, metrics)`.
grad_fn = jax.jit(jax.grad(params_loss_fn, has_aux=True))

We wrap the whole update step in a jitted `update_fn`, since this makes the updates much faster.

In [None]:
@jax.jit
def update_fn(params, opt_state, x0, rng):
  """Performs one optimization step on a batch.

  Args:
    params: Current network parameters.
    opt_state: Current optimizer state.
    x0: Batch of clean data points.
    rng: PRNG key forwarded to the gradient function.

  Returns:
    A `(params, opt_state, metrics)` triple with the updated parameters, the
    updated optimizer state and the metrics returned by `grad_fn`.
  """
  gradients, metrics = grad_fn(params, x0, rng)
  param_updates, next_opt_state = optimizer.update(gradients, opt_state)
  next_params = optax.apply_updates(params, param_updates)
  return next_params, next_opt_state, metrics

## Train the model

In [None]:
# Training hyper-parameters.
# `nepochs` bounds the epoch loop; `batch_size` is the number of data points
# sliced off per update step.
nepochs = 100
batch_size = 256

In [None]:
rng = jax.random.PRNGKey(0)

# Use a dedicated key for initialization so that `rng` itself is never both
# consumed and used as the parent of later splits.
rng, init_rng = jax.random.split(rng)
params = network.initialize_variables(
    input_shape=(1, 2),
    conditioning_shape=None,
    key=init_rng,
    is_training=True,
)['params']

opt_state = optimizer.init(params)

# Cumulative loss of each epoch, in training order.
losses = []
# `range(1, nepochs + 1)` so that exactly `nepochs` epochs are run, numbered
# from 1 in the log line below.
for epoch in tqdm.tqdm(range(1, nepochs + 1)):
  epoch_loss = 0
  # Sweep over the data in consecutive slices; the last slice is shorter when
  # `len(data)` is not a multiple of `batch_size`.
  for i in range(0, len(data), batch_size):
    batch = data[i : i + batch_size]
    # Derive a fresh key for this step and keep `rng` only as the carry.
    # Passing the carry itself to `update_fn` would reuse a key that is split
    # again afterwards.
    rng, step_rng = jax.random.split(rng)
    params, opt_state, metrics = update_fn(params, opt_state, batch, step_rng)
    epoch_loss += metrics['loss']
  print(f'Epoch = {epoch}, Cumulative epoch loss = {epoch_loss}')
  losses.append(epoch_loss)

In [None]:
# Plot the cumulative loss of each epoch, in training order.
plt.plot(losses)

# Sampling

In [None]:
# `checkify` lives in a submodule that `import jax` alone is not guaranteed to
# load, so import it explicitly instead of relying on attribute access.
from jax.experimental import checkify

# Shape of a single generated sample (two coordinates).
shape = (2,)
# Bind the trained parameters to the network for inference.
base_inference_fn = wrappers.FlaxLinenInferenceFn(
    network=network,
    params=params,
)
inference_fn = diffusion_inference.GuidedDiffusionInferenceFn(
    base_inference_fn=base_inference_fn
)
time_schedule = time_scheduling.UniformTimeSchedule()
# NOTE(review): `stoch_coeff=1.0` presumably selects the fully stochastic
# variant of the DDIM step -- confirm in `hd.sampling.gaussian_step_sampler`.
stepper = gaussian_step_sampler.DDIMStep(
    corruption_process=process, stoch_coeff=1.0
)

sampler = sampling.DiffusionSampler(
    time_schedule=time_schedule, stepper=stepper, num_steps=100
)
sampler = functools.partial(sampler, inference_fn=inference_fn)
# With checkify, the jitted sampler returns `(error, result)` instead of
# `result`; the error must be inspected explicitly after the call.
sampler = jax.jit(checkify.checkify(sampler))

key = jax.random.PRNGKey(0)
# Independent keys for the initial noise and for the sampler itself, so that
# no key is consumed twice.
noise_key, sampler_key = jax.random.split(key)
eval_batch_size = 2048
initial_noise = jax.random.normal(
    key=noise_key, shape=(eval_batch_size, *shape)
)
eval_cond = jnp.zeros(shape=(eval_batch_size,))
# NOTE(review): training called the network with `conditioning=None` and no
# conditioning embedders are configured -- confirm that this "mean" entry is
# expected (or ignored) by the inference function.
conditioning = {"mean": eval_cond}

err, (out, all_steps) = sampler(
    rng=sampler_key, initial_noise=initial_noise, conditioning=conditioning
)
# Raise if any runtime check captured by checkify failed; without this call a
# failed check would go unnoticed.
err.throw()

In [None]:
# Overlay the generated samples and the training data, one marker per point.
plt.plot(out.xt[:, 0], out.xt[:, 1], '*', label='generated')
plt.plot(data[:, 0], data[:, 1], '*', label='original')
plt.legend()