# Example for generation using stochastic interpolants.

Here we consider the simple example of generating handwritten digits (from MNIST) using the stochastic interpolant formalism.

We consider the simplest instantiation of the stochastic interpolants, which coincides with rectified flows.

Basically, suppose that we have two distributions of $d$-dimensional vectors $X_0$ and $X_1$, then we define the interpolant:
$$X_t = (1 -t) X_0 + t X_1$$
for $t \in [0, 1]$. Here we consider $X_0 \sim N(0, I_{d})$, and $X_1$ is a random variable given by handwritten digits, with samples taken from the MNIST dataset. Here $d$ is the number of pixels of an MNIST sample.


### Downloading dependencies.

We use the `swirl-dynamics` library for most of the heavy lifting, so we install it using pip.

In [None]:
!pip install git+https://github.com/google-research/swirl-dynamics.git@main

We also import all the necessary libraries.

In [None]:
from clu import metric_writers
import jax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from orbax import checkpoint
from swirl_dynamics.projects.debiasing.rectified_flow import models as reflow_models
from swirl_dynamics.projects.debiasing.stochastic_interpolants import interpolants
from swirl_dynamics.projects.debiasing.stochastic_interpolants import losses
from swirl_dynamics.projects.debiasing.stochastic_interpolants import models
from swirl_dynamics.projects.debiasing.stochastic_interpolants import trainers
from swirl_dynamics.templates import callbacks
from swirl_dynamics.templates import train
import tensorflow as tf
import tensorflow_datasets as tfds

### Define Hyper-Parameters

For simplicity we define the parameters inside a `ConfigDict`.

In [None]:
import ml_collections

# All hyper-parameters of the example live in this single `ConfigDict`.
config = ml_collections.ConfigDict()

# Parameters for the optimizer and the learning-rate schedule.
config.initial_lr = 1e-6
config.peak_lr = 1e-4
config.warmup_steps = 10_000
config.end_lr = 1e-6
config.beta1 = 0.999
# Threshold intended for the clipping mechanism of the optimizer.
config.clip = 1.0

# Parameters for checkpointing.
config.save_interval_steps = 1000
config.max_checkpoints_to_keep = 10

# Parameters for the training and evaluation loops.
# NOTE: `num_train_steps` used to be assigned twice in this cell (100_000 and
# then 50_000). Only the second assignment was effective, so it is the one kept.
config.num_train_steps = 50_000
config.metric_aggregation_steps = 1000
config.eval_every_steps = 10_000
config.num_batches_per_eval = 2
config.batch_size_training = 64
config.batch_size_eval = 32

# Parameters for the instantiation of the neural network.
# Here we will use a simple convolutional U-net with FiLM layers.
config.out_channels = 1
config.num_channels = (64, 128)
config.downsample_ratio = (2, 2)
config.num_blocks = 4
config.noise_embed_dim = 128
config.padding = "SAME"
config.use_attention = True
config.use_position_encoding = True
config.num_heads = 8
# NOTE(review): presumably the standard deviation of the normalized data (same
# value as `DATA_STD`); confirm where it is consumed.
config.sigma_data = 0.31

# Seed for the random number generator and decay rate of the EMA of the
# parameters; both are passed to the trainer.
config.seed = 666
config.ema_decay = 0.99

# The shapes of x_0 and x_1.
# The leading one represents the batch dimension.
config.input_shapes = ((1, 28, 28, 1), (1, 28, 28, 1))

### Downloading the data.

For the data we leverage the MNIST dataset in tensorflow datasets, to which we introduce an extra field with random Normal noise.

In [None]:
def get_mnist_dataset(split: str, batch_size: int, repeat: bool = True):
  """Builds a NumPy iterator over batched (noise, MNIST image) pairs.

  Args:
    split: Name of the MNIST split passed to `tfds.load`.
    batch_size: Number of samples per batch.
    repeat: Whether to repeat the dataset before batching.

  Returns:
    An iterator yielding dictionaries with two keys: "x_0", normal noise with
    the same shape as the image, and "x_1", the image cast to float32 and
    divided by 255.
  """

  def _to_interpolant_pair(sample):
    image = sample["image"]
    return {
        # Source sample: fresh normal noise shaped like the image.
        "x_0": tf.random.normal(shape=image.shape, mean=0.0),
        # Target sample: pixel values rescaled to [0, 1].
        "x_1": tf.cast(image, tf.float32) / 255.0,
    }

  dataset = tfds.load("mnist", split=split).map(_to_interpolant_pair)
  if repeat:
    dataset = dataset.repeat()
  return (
      dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE).as_numpy_iterator()
  )


# The standard deviation of the normalized dataset.
# NOTE(review): this constant is not referenced anywhere else in the visible
# notebook and has the same value as `config.sigma_data` (0.31). Presumably it
# is meant for the preconditioning of the neural network parametrization;
# confirm whether it is still needed.
DATA_STD = 0.31

Instantiating the dataloaders. This will download the data to disk so it can be fed directly to the training pipeline.

In [None]:
# Training batches come from the "train" split; `repeat` keeps its default
# value (True), so the dataset is repeated.
train_dataloader = get_mnist_dataset(
    split="train", batch_size=config.batch_size_training
)
# Evaluation batches come from the "test" split, also repeated.
eval_dataloader = get_mnist_dataset(
    split="test", batch_size=config.batch_size_eval
)

Here we extract one batch and we probe the elements inside a batch.


In [None]:
# Pull a single batch from the training iterator and print its keys and the
# shapes of both fields.
batch = next(iter(train_dataloader))
print(f"Keys of the batch: {batch.keys()}")
print(f"Shape of the x_0: {batch['x_0'].shape}")
print(f"Shape of the x_1: {batch['x_1'].shape}")

In [None]:
# Side-by-side view of the second element of the batch (index 1, channel 0)
# for the noise field and for the digit field.
plt.figure(figsize=(10, 10))
panels = (
    ("x_0", "Sample from initial distribution x_0"),
    ("x_1", "Sample from target distribution x_1"),
)
for position, (field, title) in enumerate(panels, start=1):
  plt.subplot(1, 2, position)
  plt.imshow(batch[field][1, :, :, 0])
  plt.title(title)

### Defining the stochastic interpolant optimizers

Here we define the learning rate schedule; for simplicity we use a linear ramp-up followed by a cosine decay schedule. This can be further tweaked, but empirically it has been shown to provide reasonable results for this type of problem.

For the optimizer we use the Adam optimizer, and we also add a clipping mechanism to help avoid instabilities.


In [None]:
# Learning-rate schedule built from the config: a linear warm-up from
# `initial_lr` to `peak_lr`, followed by a cosine decay towards `end_lr`.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=config.initial_lr,
    peak_value=config.peak_lr,
    warmup_steps=config.warmup_steps,
    decay_steps=config.num_train_steps,
    end_value=config.end_lr,
)

# Adam preceded by a clipping transformation. The clipping step was announced
# in the text and configured through `config.clip`, but it was missing from
# the chain, so the optimizer was plain Adam.
optimizer = optax.chain(
    optax.clip(config.clip),
    optax.adam(
        learning_rate=schedule,
        b1=config.beta1,
    ),
)

## Instantiating the model

In this case the model is a fully convolutional U-net model, using ResNet blocks with a Fourier embedding layer for the time.

Here this model parametrizes the velocity vector field in the stochastic interpolant framework.

I.e., we have an interpolant of the form:
$$x_t = \alpha(t) x_0 + \beta(t) x_1 $$
where
$$\alpha(t) = 1-t, \qquad \text{and} \qquad \beta(t) = t.$$

Here we use the ``LinearInterpolant`` class already defined in [``interpolants``](https://github.com/google-research/swirl-dynamics/blob/main/swirl_dynamics/projects/debiasing/stochastic_interpolants/interpolants.py).

In [None]:
# Linear interpolant between x_0 and x_1, as described in the text above.
interpolant = interpolants.LinearInterpolant()

We show how the interpolant progressively transforms the Gaussian noise into one of the target samples.

In [None]:
# Number of interpolation times shown, evenly spaced in [0, 1].
num_snapshots = 6

# First pair of the batch: the batch axis is kept (slice 0:1) and the channel
# axis is dropped (index 0).
x_0_dummy = batch['x_0'][0:1, ..., 0]
x_1_dummy = batch['x_1'][0:1, ..., 0]

t_array = jnp.linspace(0, 1, num_snapshots)
fig, axs = plt.subplots(1, num_snapshots, figsize=(24, 4))
for ax, t in zip(axs, t_array):
  # The interpolant is called with a time array of shape (1,).
  x_t = interpolant(t[None], x_0_dummy, x_1_dummy)
  ax.imshow(x_t[0, :, :])
  ax.set_title(f'Sample from x_{t:<.3f}')

plt.show()

We consider a generative model that is instantiated by solving the following ODE:
$$\dot{x} = v_{\theta}(x, t), \qquad t \in [0, 1],$$
in this case, the model defined below parametrizes $v_{\theta}(x, t)$.

In [None]:
# Network parametrizing the velocity field v_theta(x, t). The architecture
# hyper-parameters are read from `config` instead of being duplicated here as
# literals, so that the config is the single source of truth (the values are
# the same as the previously hard-coded ones).
flow_model = reflow_models.RescaledUnet(
    out_channels=config.out_channels,
    num_channels=config.num_channels,
    downsample_ratio=config.downsample_ratio,
    num_blocks=config.num_blocks,
    noise_embed_dim=config.noise_embed_dim,
    padding=config.padding,
    use_attention=config.use_attention,
    use_position_encoding=config.use_position_encoding,
    num_heads=config.num_heads,
)

We also need to define how to measure the distance between the output of the neural network and the velocity of the interpolant. In this case we consider the loss:
$$|v_{\theta}(x_t, t) - \dot{x}_{t}|^2,$$
which can be further simplified to
$$|v_{\theta}(x_t, t) - (x_1 - x_0)|^2,$$
using the fact that $\dot{x}_{t} = x_1 - x_0.$

An equivalent loss was already defined in the [``losses``](https://github.com/google-research/swirl-dynamics/blob/main/swirl_dynamics/projects/debiasing/stochastic_interpolants/losses.py) module.

In [None]:
# Alias for the velocity loss of the `losses` module; it is passed to the
# model in the next cell.
loss_stochastic_interpolant = losses.velocity_loss

Now we have all the required elements to create an instance of ``StochasticInterpolantModel``, which encapsulates all the information at the model level.

In [None]:
# Shape of a single x_0 sample: the configured shape without its leading
# batch dimension. This must agree with the expected sample shape.
sample_shape = tuple(config.input_shapes[0][1:])

model = models.StochasticInterpolantModel(
    input_shape=sample_shape,
    # Network used for the velocity field.
    flow_model=flow_model,
    # Defines the type of stochastic interpolant.
    interpolant=interpolant,
    # Defines the type of loss used for the training.
    loss_stochastic_interpolant=loss_stochastic_interpolant,
    num_eval_cases_per_lvl=8,
)

### Building the trainer

Now, we just need to instantiate the trainer, which contains all the information to run the training loop. This includes the model, the optimizer, and the checkpointer.

In [None]:
# Defining the trainer: it receives the model, a random key built from the
# configured seed, the optimizer and the EMA decay rate.
trainer = trainers.StochasticInterpolantTrainer(
    model=model,
    rng=jax.random.key(config.seed),
    optimizer=optimizer,
    ema_decay=config.ema_decay,
)

# Setting up checkpointing: how often a checkpoint is saved (in steps) and how
# many checkpoints are kept.
ckpt_options = checkpoint.CheckpointManagerOptions(
    save_interval_steps=config.save_interval_steps,
    max_to_keep=config.max_checkpoints_to_keep,
)

# Sets up the working directory. It is used below as the base directory of
# the metric writer and of the checkpoints.
workdir = "/content"  # typical current position in Colab.

In [None]:
### If you need to remove the checkpoint to start from scratch.
# !rm -Rf /content/checkpoints

### Running the training loop.

We run the training loop.

Here we seek to solve the problem

$$ \min_{\theta} \mathbb{E}_{t \sim U[0, 1]} \mathbb{E}_{(x_0, x_1) \sim \mu_0 \otimes \mu_1} \left \| \dot{x}_t - v_{\theta}(x_t, t)  \right \|^2,$$
where $x_t = t x_1 + (1-t) x_0$, $\mu_0 = N(0, I_d)$, and $\mu_1$ is the distribution of MNIST digits.


This loss can be further simplified as
$$ \min_{\theta} \mathbb{E}_{t \sim U[0, 1]} \mathbb{E}_{(x_0, x_1) \sim \mu_0 \otimes \mu_1} \left[ | v_{\theta}(x_t, t)|^2  - 2 (x_1 - x_0) \cdot  v_{\theta}(x_t, t) \right],$$
using the fact that $\dot{x}_t = x_1 - x_0$ and that $\dot{x}_t$ is independent of $\theta$, so the term $|\dot{x}_t|^2$ can be dropped.

Note that the full training run takes around 15-20 minutes on a TPU v6e (Trillium).


In [None]:
# Run training loop: trains for `config.num_train_steps` steps on the training
# iterator, evaluating every `config.eval_every_steps` steps on
# `config.num_batches_per_eval` batches of the evaluation iterator.

train.run(
    train_dataloader=train_dataloader,
    trainer=trainer,
    workdir=workdir,
    total_train_steps=config.num_train_steps,
    metric_aggregation_steps=config.metric_aggregation_steps,
    eval_dataloader=eval_dataloader,
    eval_every_steps=config.eval_every_steps,
    num_batches_per_eval=config.num_batches_per_eval,
    metric_writer=metric_writers.create_default_writer(
        workdir, asynchronous=False
    ),
    callbacks=(
        # Progress bar that monitors the "train_loss" metric.
        callbacks.TqdmProgressBar(
            total_train_steps=config.num_train_steps,
            train_monitors=("train_loss",),
        ),
        # This callback saves model checkpoint periodically.
        callbacks.TrainStateCheckpoint(
            base_dir=workdir,
            options=ckpt_options,
        ),
        # TODO: add a plot callback.
    ),
)

# Running Inference

Loading extra libraries for running inference.

In [None]:
import functools
from swirl_dynamics.lib.solvers import ode as ode_solvers
from tqdm import tqdm

Define the dataloader to run inference.

In [None]:
# Single pass over the "test" split: `repeat=False`, so the iterator is
# exhausted after going through the data once.
test_dataloader = get_mnist_dataset(
    split="test", batch_size=config.batch_size_eval, repeat=False
)

### Load the last trained model and define the dynamics.

In [None]:
# Restore the train state from the checkpoints folder inside `workdir`.
# NOTE(review): `step=None` presumably selects the latest checkpoint — confirm
# against `TrainState.restore_from_orbax_ckpt`.
trained_state = trainers.TrainState.restore_from_orbax_ckpt(
    f"{workdir}/checkpoints", step=None
)

In [None]:
# Turn the flow network into a dynamics function usable by the ODE solver.
# NOTE(review): `autonomous=False` presumably means that the time is passed to
# the network, and `is_training=False` is presumably forwarded to the module to
# run it in inference mode — confirm against `nn_module_to_dynamics`.
latent_dynamics_fn = ode_solvers.nn_module_to_dynamics(
    model.flow_model,
    autonomous=False,
    is_training=False,
)

We define the ODE solver (here a 4th-order Runge-Kutta method) and other details, such as the number of steps.

In [None]:
# Number of uniform integration steps used to go from t = 0 to t = 1.
num_sampling_steps = 128

integrator = ode_solvers.RungeKutta4()
# Bind the dynamics, the time grid and the trained variables, so that the
# resulting function only takes the initial condition x_0.
integrate_fn = functools.partial(
    integrator,
    latent_dynamics_fn,
    # The grid must include the end point t = 1. The former
    # `jnp.arange(0.0, 1.0, 1.0 / num_sampling_steps)` stopped at
    # 1 - 1 / num_sampling_steps, so the ODE was never integrated over the full
    # interval [0, 1].
    tspan=jnp.linspace(0.0, 1.0, num_sampling_steps + 1),
    params=trained_state.model_variables,
)

integrate_fn_jit = jax.jit(integrate_fn)

In [None]:
# Integrate a single batch of noise samples to check the shapes.
# NOTE(review): if `iter()` returns the same underlying iterator, the batch
# taken here is consumed and will not be seen again by a later loop over
# `test_dataloader` — confirm.
batch = next(iter(test_dataloader))
print(f"Shape of the x_0 condition, {batch['x_0'].shape}")
out_put = integrate_fn_jit(batch["x_0"])
print(f"Shape of the generated x_1 {out_put.shape}")

In [None]:
# Show the last entry along the leading axis for the first sample (channel 0).
# NOTE(review): assumes the leading axis indexes the integration times, so
# that index -1 is the generated sample — confirm.
plt.imshow(out_put[-1, 0, :, :, 0])

### Running Inference Loop.

This may take a non-negligible amount of time.

In [None]:
# Initial conditions and generated samples, one entry per batch.
input_list = []
output_list = []

# Height and width of a sample, read from the configured shape of x_1.
image_height, image_width = config.input_shapes[1][1:3]

for batch in tqdm(test_dataloader):
  noise = batch["x_0"]
  input_list.append(noise)

  # Keep only the last entry of the integrated output and reshape it to
  # (batch, height, width) before converting it to a NumPy array.
  final_state = integrate_fn_jit(noise)[-1]
  images = final_state.reshape((-1, image_height, image_width))
  output_list.append(np.array(images))

In [None]:
# Stack the per-batch outputs along the first axis into a single array.
output_array = np.concatenate(output_list, axis=0)
print(f"Shape of the output array: {output_array.shape}")

In [None]:
# Number of generated samples to display.
num_plots = 6
num_samples = output_array.shape[0]
# Evenly spaced indices, from the first to the last generated sample.
idx_samples = np.linspace(0, num_samples - 1, num_plots).astype(int)

fig, axs = plt.subplots(1, num_plots, figsize=(num_plots * 4, 4))
for ax, idx in zip(axs, idx_samples):
  ax.imshow(output_array[idx, :, :])
  ax.set_title(f"Sample number: {idx}")

plt.show()