# Flow-map Matching Distillation Demo

In this notebook, we provide a simple example of how to use flow-map matching for distillation.

In this case, we train a conditional rectified flow model, instantiated using the stochastic interpolant formalism, to generate digits from the MNIST dataset, and then distill it into a flow-map model.

---

### Training the Teacher Model

We use a rectified flow to bridge two distributions, namely $\rho_0$ and $\rho_1$, where the former is a normal distribution and the latter is the distribution of handwritten digits from the MNIST dataset.

Thus, we consider a straight path between these two distributions given by:

$$x_t = (1-t) x_0 + t x_1,$$

where $x_0 \sim \rho_0$ and $x_1 \sim \rho_1$.

Following this path, we train a rectified flow model by minimizing:

$$ \min_{\phi} \mathbb{E}_{\rho_0, \rho_1}\mathbb{E}_{t \sim U[0,1]} \| v_{\phi}(x_t, t) - \dot{x_t} \|^2,$$
where $x_t = (1-t) x_0 + t x_1$ and $\dot{x_t} = x_1 - x_0$.
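As a quick sanity check of the objective above, the following NumPy sketch computes a Monte-Carlo estimate of the rectified-flow loss for a given velocity function. It is not the `losses.velocity_loss` used later in the notebook. The toy data are an assumption: we deliberately couple $x_1 = x_0 + 3$ so that the optimal velocity is known in closed form (the constant field $3$).

```python
import numpy as np

def rectified_flow_loss(v_fn, x0, x1, t):
    """Monte-Carlo estimate of E || v(x_t, t) - (x_1 - x_0) ||^2.

    x0, x1: arrays of shape (batch, dim); t: array of shape (batch,).
    """
    t_b = t[:, None]                  # broadcast t over the feature axis
    xt = (1.0 - t_b) * x0 + t_b * x1  # straight-line interpolant x_t
    target = x1 - x0                  # time derivative of x_t
    residual = v_fn(xt, t) - target
    return np.mean(np.sum(residual**2, axis=-1))

rng = np.random.default_rng(0)
x0 = rng.standard_normal((1024, 2))
x1 = x0 + 3.0                 # toy coupling (assumption): x_1 is a shifted copy of x_0
t = rng.uniform(size=1024)

# For this coupling x_1 - x_0 == 3 everywhere, so the constant field v = 3 is optimal.
perfect_v = lambda x, t: np.full_like(x, 3.0)
zero_v = lambda x, t: np.zeros_like(x)
print(rectified_flow_loss(perfect_v, x0, x1, t))  # close to zero
print(rectified_flow_loss(zero_v, x0, x1, t))     # close to 18 = 2 dims * 3^2
```

With independent samples of $\rho_0$ and $\rho_1$, as in the notebook, the minimum of the loss is generally not zero, since the same $x_t$ can be reached from different pairs $(x_0, x_1)$; the minimizer is the conditional expectation $\mathbb{E}[x_1 - x_0 \mid x_t]$.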

---

### Distillation Through Flow-Map Matching

Once $v_{\phi}$ is trained, we use a flow-map matching approach to distill it.

In a nutshell, we consider the flow-map $X^{t, s}$ as the map such that:

$$ X^{t, s}(x_s) = \int_{s}^{t} v_{\phi}(x_{\tau}, \tau) d \tau  + x_s = x_t.$$

I.e., along a trajectory, it maps $x_s$ to $x_t$.

Now, following [1], we take the derivative with respect to $t$ and apply the fundamental theorem of calculus to find that:
$$\partial_t X^{t, s}(x_s) = v_{\phi}(x_t, t)$$

Then, by replacing $x_t$ with $X^{t, s}(x_s)$, we find that a flow-map needs to satisfy:

$$\partial_t X^{t, s}(x_s) = v_{\phi}( X^{t, s}(x_s), t).$$

This is called the Lagrangian formulation of the flow-map. We then proceed to train the flow-map by softly imposing the property above, i.e.:

$$\min_{\theta} \mathbb{E}_{\rho_0, \rho_1} \int_{[0,1]^2} \| \partial_t X_{\theta}^{t, s}(x_s) -  v_{\phi}( X_{\theta}^{t, s}(x_s), t) \|^2 ds\, dt, $$

where $x_s = (1-s) x_0 + s x_1$, and we also impose the condition $X_{\theta}^{s, s}(x_s) = x_s$ within the architecture, namely we define

$$X_{\theta}^{t, s}(x) = (1 - (t-s) ) x + (t-s) f_{\theta}(x, t, s).$$
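The Lagrangian condition and the parametrization can be checked numerically on a toy problem. The sketch below assumes a hypothetical 1-D teacher $v(x, t) = \cos(t)\, x$, whose flow map has the closed form $X^{t,s}(x) = x\, e^{\sin t - \sin s}$, and approximates $\partial_t$ with a central finite difference (in JAX one could instead use a forward-mode derivative such as `jax.jvp`). The exact flow map drives the loss to (numerically) zero, an untrained student with $f_\theta(x, t, s) = x$ (which makes $X^{t,s}$ the identity) does not, and the parametrization returns $x$ whenever $t = s$, regardless of $f$.

```python
import numpy as np

def velocity(x, t):
    # Hypothetical toy teacher: dx/dt = cos(t) x.
    return np.cos(t) * x

def exact_flow_map(x, t, s):
    # Closed-form flow map of the toy ODE: X^{t,s}(x) = x exp(sin t - sin s).
    return x * np.exp(np.sin(t) - np.sin(s))

def parametrized_flow_map(f, x, t, s):
    # X^{t,s}(x) = (1 - (t - s)) x + (t - s) f(x, t, s); reduces to x when t == s.
    return (1.0 - (t - s)) * x + (t - s) * f(x, t, s)

def lagrangian_loss(flow_map, x0, x1, t, s, h=1e-5):
    """Monte-Carlo estimate of E || d/dt X^{t,s}(x_s) - v(X^{t,s}(x_s), t) ||^2."""
    xs = (1.0 - s) * x0 + s * x1  # point on the straight interpolant at time s
    # Central finite difference in t, standing in for an exact time derivative.
    dX_dt = (flow_map(xs, t + h, s) - flow_map(xs, t - h, s)) / (2.0 * h)
    residual = dX_dt - velocity(flow_map(xs, t, s), t)
    return np.mean(residual**2)

rng = np.random.default_rng(0)
n = 4096
x0 = rng.standard_normal(n)        # 1-D samples of rho_0
x1 = rng.standard_normal(n) + 2.0  # 1-D toy "data" samples
t, s = rng.uniform(size=n), rng.uniform(size=n)

# Untrained student with f(x, t, s) = x, which makes X^{t,s} the identity map.
def identity_student(x, t, s):
    return parametrized_flow_map(lambda y, tt, ss: y, x, t, s)

print(lagrangian_loss(exact_flow_map, x0, x1, t, s))    # close to zero
print(lagrangian_loss(identity_student, x0, x1, t, s))  # clearly positive
```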

---

### Sampling Using the Flow-Map Model

For sampling, we can use either one-shot or few-shot sampling.

* **One-shot**: Here, we just use $X^{1, 0}$, which takes a sample from $\rho_0$ and maps it to a sample of $\rho_1$.

* **Few-shot**: We assume a partition of $[0, 1]$, e.g., $0=t_0< t_1< \dots < t_{n-1} < t_n = 1$, and we factorize $X^{1,0}(x_0) = X^{1, t_{n-1}} \circ X^{t_{n-1}, t_{n-2}} \circ \dots \circ X^{t_1, 0} (x_0)$.
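The few-shot factorization relies on the composition property $X^{t,u} \circ X^{u,s} = X^{t,s}$ of exact flow maps. The sketch below checks it using the closed-form flow map of a toy ODE, $\dot{x} = \cos(t)\, x$, which is an assumption standing in for a trained network: for an exact map, one-shot and few-shot sampling agree up to round-off.

```python
import numpy as np

def exact_flow_map(x, t, s):
    # Closed-form flow map of the toy ODE dx/dt = cos(t) x (a stand-in for X^{t,s}).
    return x * np.exp(np.sin(t) - np.sin(s))

def few_shot_sample(flow_map, x0, num_steps):
    # Compose X^{t_{i+1}, t_i} over the uniform partition t_i = i / num_steps.
    times = np.linspace(0.0, 1.0, num_steps + 1)
    x = x0
    for s, t in zip(times[:-1], times[1:]):
        x = flow_map(x, t, s)
    return x

x0 = np.random.default_rng(0).standard_normal(4)
one_shot = exact_flow_map(x0, 1.0, 0.0)  # X^{1,0}(x_0)
for num_steps in (1, 2, 4, 8):
    gap = np.max(np.abs(few_shot_sample(exact_flow_map, x0, num_steps) - one_shot))
    print(num_steps, gap)  # round-off level for every num_steps
```

For a learned $X_\theta$ the composition holds only approximately, which is why the number of steps can affect sample quality.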

---

## References

[1] Flow Map Matching with Stochastic Interpolants: A Mathematical Framework for Consistency Models. Nicholas M. Boffi, Michael S. Albergo, and Eric Vanden-Eijnden.

As usual, we start by installing the `swirl-dynamics` library.

In [None]:
!pip install git+https://github.com/google-research/swirl-dynamics.git@main --quiet

In [None]:
import functools
import os
from clu import metric_writers
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from orbax import checkpoint

# Imports from swirl-dynamics codebase.
from swirl_dynamics.lib.solvers import ode as ode_solvers
from swirl_dynamics.projects.debiasing.stochastic_interpolants import backbones
from swirl_dynamics.projects.debiasing.stochastic_interpolants import interpolants
from swirl_dynamics.projects.debiasing.stochastic_interpolants import losses
from swirl_dynamics.projects.debiasing.stochastic_interpolants import models as flow_models
from swirl_dynamics.projects.debiasing.stochastic_interpolants import trainers
from swirl_dynamics.projects.distillation.flow_map_matching import models as flow_map_models
from swirl_dynamics.projects.distillation.flow_map_matching import trainers as flow_map_trainers
from swirl_dynamics.templates import callbacks
from swirl_dynamics.templates import train

import tensorflow as tf
import tensorflow_datasets as tfds


We define the dataloader. Here we leverage the data in `tfds` and apply the appropriate transformations. Following the convention above, each batch has three fields:
- `x_0`: a sample from $\rho_0$, i.e., a normal distribution,
- `x_1`: a sample from $\rho_1$, i.e., an image from the MNIST dataset, and
- `emb:label`: the class of the `x_1` sample, i.e., which digit it corresponds to.

In [None]:
def get_mnist_dataset(split: str, batch_size: int):
  ds = tfds.load("mnist", split=split)
  ds = ds.map(
      # Build each example: a Gaussian sample `x_0` (rho_0), the image `x_1`
      # (rho_1) normalized to [0, 1], and its digit class `emb:label`.
      lambda x: {"x_0": tf.random.normal(shape=x["image"].shape, mean=0.),
                 "x_1": tf.cast(x["image"], tf.float32) / 255.0,
                 "emb:label": x["label"]}
  )
  ds = ds.repeat()
  ds = ds.batch(batch_size)
  ds = ds.prefetch(tf.data.AUTOTUNE)
  ds = ds.as_numpy_iterator()
  return ds
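To see what a batch looks like without downloading MNIST, the hypothetical helper `fake_mnist_batch` below mimics the fields and shapes produced by `get_mnist_dataset`, using random pixels instead of real digits. The `(28, 28, 1)` image shape matches the MNIST images used here; the helper is only an illustration and is not used elsewhere in the notebook.

```python
import numpy as np

def fake_mnist_batch(batch_size, rng):
    # Hypothetical stand-in for one batch of `get_mnist_dataset(...)`,
    # with random pixels instead of real MNIST digits.
    images = rng.integers(0, 256, size=(batch_size, 28, 28, 1), dtype=np.uint8)
    return {
        "x_0": rng.standard_normal(images.shape).astype(np.float32),  # sample of rho_0
        "x_1": images.astype(np.float32) / 255.0,                      # image in [0, 1]
        "emb:label": rng.integers(0, 10, size=(batch_size,)),           # digit class
    }

fake_batch = fake_mnist_batch(8, np.random.default_rng(0))
print({key: (value.shape, value.dtype) for key, value in fake_batch.items()})
```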

We define the parameters of the neural architecture and the training hyperparameters. For simplicity, we store them in a `ConfigDict` and use them when necessary.

In [None]:
import ml_collections

config = ml_collections.ConfigDict()

config.initial_lr = 1e-6
config.peak_lr = 1e-3
config.warmup_steps = 10_000
config.num_train_steps = 100_000
config.num_train_steps_flow_map = 150_000
config.end_lr = 1e-6
config.beta1 = 0.999
config.clip = 10.0
config.save_interval_steps = 1000
config.max_checkpoints_to_keep = 10

config.metric_aggregation_steps = 100
config.eval_every_steps = 5_000
config.num_batches_per_eval = 1
config.batch_size = 256
config.batch_size_flow_map = 32

config.out_channels = 1
config.num_channels = (64, 128)
config.downsample_ratio = (2, 2)
config.num_blocks = 4
config.noise_embed_dim = 128
config.padding = "SAME"
config.use_attention = True
config.use_position_encoding = True
config.num_heads = 8
config.sigma_data = 0.31
config.seed = 666
config.ema_decay = 0.99


config.input_shapes = ((1, 28, 28, 1), (1, 28, 28, 1))

## Defining the rectified flow (teacher) model.

Here we instantiate the neural architecture, wrap it in a `model` object, define the training hyperparameters and the trainer, and then train the model. Since this is already explained for the rectified flow [here](https://github.com/google-research/swirl-dynamics/blob/main/swirl_dynamics/projects/debiasing/rectified_flow/colab/demo_reflow.ipynb) and for stochastic interpolants [here](https://github.com/google-research/swirl-dynamics/tree/main/swirl_dynamics/projects/debiasing/stochastic_interpolants/colabs), we only provide a quick overview of the training pipeline.


We define, instantiate, and wrap the conditional rectified flow model using a stochastic interpolant.

In [None]:
flow_model = flow_models.RescaledUnet(
    out_channels=1,
    num_channels=config.num_channels,
    downsample_ratio=config.downsample_ratio,
    num_blocks=config.num_blocks,
    noise_embed_dim=config.noise_embed_dim,
    padding=config.padding,
    use_attention=config.use_attention,
    use_position_encoding=config.use_position_encoding,
    num_heads=config.num_heads,
    cond_embed_fn=backbones.MergeCategoricalEmbCond,
    cond_embed_kwargs={
        "cond_key": "emb:label",
        "num_classes": 10,
        "features_embedding": config.noise_embed_dim,
        }
)

model = flow_models.ConditionalStochasticInterpolantModel(
    input_shape=config.input_shapes[0][1:],
    cond_shape={"emb:label": ()},
    conditioning_keys=("emb:label",),
    flow_model=flow_model,
    interpolant=interpolants.RectifiedFlow(),  # Defines the type of stochastic interpolant.
    loss_stochastic_interpolant=losses.velocity_loss,  # Defines the type of loss used for the training.
)

Define the optimizer (with its corresponding schedule) and the trainer (including the checkpointer). We also define the working directory where the checkpoints are saved.

In [None]:
schedule = optax.warmup_cosine_decay_schedule(
    init_value=config.initial_lr,
    peak_value=config.peak_lr,
    warmup_steps=config.warmup_steps,
    decay_steps=config.num_train_steps,
    end_value=config.end_lr,
)

optimizer = optax.chain(
    optax.clip(config.clip),
    optax.adam(
        learning_rate=schedule,
        b1=config.beta1,
    ),
)
trainer = trainers.DistributedStochasticInterpolantTrainer(
    model=model,
    rng=jax.random.PRNGKey(config.seed),
    optimizer=optimizer,
    ema_decay=config.ema_decay,
)

# Setting up checkpointing.
ckpt_options = checkpoint.CheckpointManagerOptions(
    save_interval_steps=config.save_interval_steps,
    max_to_keep=config.max_checkpoints_to_keep,
)

workdir = os.path.join(os.getcwd(), "velocity")
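To see what the learning-rate schedule looks like, here is a plain-Python sketch of a linear warmup followed by a cosine decay. It mirrors our reading of `optax.warmup_cosine_decay_schedule` with its default exponent, in which `decay_steps` is the total schedule length including the warmup; treat it as an approximation and check the optax documentation for the exact semantics.

```python
import math

def warmup_cosine(step, init_value, peak_value, warmup_steps, decay_steps, end_value):
    # Linear warmup from init_value to peak_value over warmup_steps ...
    if step < warmup_steps:
        return init_value + (peak_value - init_value) * step / warmup_steps
    # ... then cosine decay to end_value; decay_steps includes the warmup phase.
    progress = min((step - warmup_steps) / (decay_steps - warmup_steps), 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_value + (peak_value - end_value) * cosine

# Teacher schedule with the values from `config` above.
for step in (0, 5_000, 10_000, 55_000, 100_000):
    print(step, warmup_cosine(step, 1e-6, 1e-3, 10_000, 100_000, 1e-6))
```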


### Training the rectified flow model.

Define the dataloaders and train the rectified flow model.
This step should take around 30 minutes on a `TPU v6e`.

In [None]:
train_dataloader = get_mnist_dataset(split="train", batch_size=config.batch_size)
eval_dataloader = get_mnist_dataset(split="test", batch_size=config.batch_size)

# We avoid using the evaluation so the training runs a bit faster.
train.run(
    train_dataloader=train_dataloader,
    trainer=trainer,
    workdir=workdir,
    total_train_steps=config.num_train_steps,
    metric_aggregation_steps=config.metric_aggregation_steps,
    num_batches_per_eval=config.num_batches_per_eval,
    metric_writer=metric_writers.create_default_writer(workdir,
                                                      asynchronous=False),
    callbacks=(
        # This callback shows a progress bar.
        callbacks.TqdmProgressBar(
            total_train_steps=config.num_train_steps,
            train_monitors=("train_loss",),
        ),
        # This callback saves model checkpoint periodically.
        callbacks.TrainStateCheckpoint(
            base_dir=workdir,
            options=ckpt_options,
        ),
    ),
)

## Testing the trained model

Here we take the model trained above and sample from it.

### Sampling from the trained model

In [None]:
# Loads the weights.
trained_flow_state = trainers.TrainState.restore_from_orbax_ckpt(
    f"{workdir}/checkpoints", step=None
)

# Extracts a batch, and gets the conditioning.
num_samples: int = 32
batch = next(iter(eval_dataloader))
cond = {'emb:label': batch['emb:label'][:num_samples]}

# Defines the dynamics.
dynamics_fn = ode_solvers.nn_module_to_dynamics(
    model.flow_model,
    autonomous=False,
    is_training=False,
    cond=cond,
)

num_sampling_steps = 128

integrator = ode_solvers.RungeKutta4()
integrate_fn = functools.partial(
    integrator,
    dynamics_fn,
    # Use num_sampling_steps + 1 grid points so that the last one is exactly t = 1.
    tspan=jnp.linspace(0.0, 1.0, num_sampling_steps + 1),
    params=trained_flow_state.model_variables,
)

Here we solve the equation

$$\dot{x_t} = v_{\phi} (x_t, t),$$
with initial condition $x_0 \sim N(0, I_d)$; the terminal state $x_1$ is our sample.

In [None]:
samples_flow = integrate_fn(batch["x_0"][:num_samples])
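Judging from the `samples_flow[-1]` indexing used for plotting, the integrator returns the state at every point of `tspan`, so the last state is a sample of $\rho_1$ only if the grid actually ends at $t = 1$. The self-contained sketch below uses a hand-written classical RK4 (not the `ode_solvers.RungeKutta4` implementation) and the toy velocity $v(x, t) = \cos(t)\, x$, an assumption chosen because $x_1 = x_0\, e^{\sin 1}$ is known exactly. It compares a grid built with `linspace(0, 1, n + 1)` with one built with `arange(0, 1, 1/n)`, whose last point is $1 - 1/n$.

```python
import numpy as np

def rk4_integrate(v, x0, tspan):
    """Classical RK4 on the time grid `tspan`; returns the state at every grid point."""
    states = [x0]
    x = x0
    for t0, t1 in zip(tspan[:-1], tspan[1:]):
        h = t1 - t0
        k1 = v(x, t0)
        k2 = v(x + 0.5 * h * k1, t0 + 0.5 * h)
        k3 = v(x + 0.5 * h * k2, t0 + 0.5 * h)
        k4 = v(x + h * k3, t1)
        x = x + h / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        states.append(x)
    return np.stack(states)

velocity = lambda x, t: np.cos(t) * x  # toy field with a closed-form solution
x0 = np.array([1.0, -0.5])
exact_x1 = x0 * np.exp(np.sin(1.0))

n = 128
grid_with_endpoint = np.linspace(0.0, 1.0, n + 1)     # ends exactly at t = 1
grid_without_endpoint = np.arange(0.0, 1.0, 1.0 / n)  # last point is 1 - 1/n
print(np.abs(rk4_integrate(velocity, x0, grid_with_endpoint)[-1] - exact_x1).max())
print(np.abs(rk4_integrate(velocity, x0, grid_without_endpoint)[-1] - exact_x1).max())
```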

### Plotting the generated samples.

Here we plot the samples. Note that the network is quite small and has not been trained extensively, so the sample quality could be improved further.

In [None]:
num_cols = 6
plt.figure(figsize=(num_cols*6, 5))
for i in range(1,num_cols+1):
    plt.subplot(1, num_cols, i)
    plt.imshow(samples_flow[-1, i-1, :, :, 0])
    plt.yticks(ticks=[], labels=[])
    plt.xticks(ticks=[], labels=[])
    plt.title(f"Label: {cond['emb:label'][i-1]}", fontsize=16)
plt.show()

## Setting up the flow-map model for distillation.

Similarly to the teacher, we now construct the flow-map (student) network. It uses a very similar architecture, with a few differences adapted to the flow-map setting (compare, e.g., `time_rescale`, `frequency_scaling`, and `num_heads` below).

In [None]:
flow_map_nn_model = flow_map_models.RescaledFlowMapUNet(
    time_rescale=1.0,
    out_channels=1,
    num_channels=config.num_channels,
    downsample_ratio=config.downsample_ratio,
    num_blocks=config.num_blocks,
    noise_embed_dim=config.noise_embed_dim,
    padding=config.padding,
    use_attention=config.use_attention,
    use_position_encoding=config.use_position_encoding,
    num_heads=32,
    frequency_scaling="exponential",
    cond_embed_fn=backbones.MergeCategoricalEmbCond,
    cond_embed_kwargs={
        "cond_key": "emb:label",
        "num_classes": 10,
        "features_embedding": config.noise_embed_dim,
        }
)

Here we define the wrapper for the conditional flow-map model. Since this model performs distillation, it also requires the teacher model (`flow_model`) and its weights, so that the teacher can be evaluated.

In [None]:
flow_map_model = flow_map_models.ConditionalLagrangianFlowMapModel(
    input_shape=config.input_shapes[0][1:],
    cond_shape={'emb:label': ()}, # This is a scalar.
    flow_model=flow_model,  # Teacher model.
    flow_map_model=flow_map_nn_model,  # Student model.
    params_flow=trained_flow_state.model_variables["params"],  # Parameters of the teacher model.
    interpolant=interpolants.RectifiedFlow(),  # Defines the type of stochastic interpolant.
)

In [None]:
# Learning-rate schedule for the flow-map (student) training; these values are
# set here directly rather than read from `config`.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=1e-6,
    peak_value=2e-4,
    warmup_steps=5_000,
    decay_steps=config.num_train_steps_flow_map,
    end_value=1e-6,
)

optimizer_flow_map = optax.chain(
    optax.clip(config.clip),
    optax.adam(
        learning_rate=schedule,
        b1=config.beta1,
    ),
)

# On a multi-device host, the training is automatically distributed.
flow_map_trainer = flow_map_trainers.DistributedLagrangianFlowMapTrainer(
    model=flow_map_model,
    rng=jax.random.PRNGKey(config.seed),
    optimizer=optimizer_flow_map,
    ema_decay=config.ema_decay,
)

# Setting up checkpointing.
ckpt_options_flow_map = checkpoint.CheckpointManagerOptions(
    save_interval_steps=config.save_interval_steps,
    max_to_keep=config.max_checkpoints_to_keep,
)

workdir_flow = os.path.join(os.getcwd(), "flow_map")


Use this command to erase any previous checkpoint that you may already have.

In [None]:
!rm -rf {workdir_flow}

We now run the training loop. If the previous cell was skipped and the working directory already contains checkpoints, the pre-trained weights are loaded from there.

This should take roughly 2 hours on a `TPU v6e`.

In [None]:
# Here we use a smaller batch size as the computational burden is higher.
train_dataloader = get_mnist_dataset(split="train",
                                     batch_size=config.batch_size_flow_map)

# Here we don't run evaluation, to make the training a bit faster.
train.run(
    train_dataloader=train_dataloader,
    trainer=flow_map_trainer,
    workdir=workdir_flow,
    total_train_steps=config.num_train_steps_flow_map,
    metric_aggregation_steps=config.metric_aggregation_steps,
    num_batches_per_eval=config.num_batches_per_eval,
    metric_writer=metric_writers.create_default_writer(workdir_flow,
                                                       asynchronous=False),
    callbacks=(
        callbacks.TqdmProgressBar(
            total_train_steps=config.num_train_steps_flow_map,
            train_monitors=("train_loss",),
        ),
        # This callback saves model checkpoint periodically.
        callbacks.TrainStateCheckpoint(
            base_dir=workdir_flow,
            options=ckpt_options_flow_map,
        ),
    ),
)

## Sampling using the distilled model.

In [None]:
trained_flow_map_state = trainers.TrainState.restore_from_orbax_ckpt(
    f"{workdir_flow}/checkpoints", step=None
)
inference_fn = flow_map_model.inference_fn(
    trained_flow_map_state.model_variables, flow_map_model.flow_map_model
)

In [None]:
# We test one-step generation starting from noise, i.e., we apply X^{1, 0}.
samples_flow_map = inference_fn(
    batch["x_0"][:num_samples],
    jnp.ones((num_samples,)),   # End time t = 1.
    jnp.zeros((num_samples,)),  # Start time s = 0.
    {"emb:label": batch["emb:label"][:num_samples]},
)

We plot the samples, using the same labels and initial noise as those used for the ODE solve above.

In [None]:
samples_flow_map.shape

In [None]:
num_cols = 6
plt.figure(figsize=(num_cols * 6, 5))
for i in range(1,num_cols+1):
    plt.subplot(1, num_cols, i)
    plt.imshow(samples_flow_map[i-1, :, :, 0])
    plt.yticks(ticks=[], labels=[])
    plt.xticks(ticks=[], labels=[])
    plt.title(f"Label: {cond['emb:label'][i-1]}", fontsize=16)
plt.show()

### Few-shot generation.

Here we check how the quality of the generated samples changes as we increase the number of applications of the network. Because the ODE flow satisfies the composition property $X^{t, u} \circ X^{u, s} = X^{t, s}$, $X^{1, 0}(x_0)$ can be written as

$$ X^{1, 0}(x_0) = X^{1, t_{n-1}} \circ X^{t_{n-1}, t_{n-2}} \circ \dots \circ X^{t_1, 0}(x_0),$$

for a given partition of $[0,1]$ into intervals $\{ [t_{i}, t_{i+1}]\}_{i=0}^{n-1}$, where $0 = t_0 < t_{1} < t_{2} < \dots < t_{n-1} < t_{n} = 1.$

Following the main paper [1], using a finer partition (i.e., more applications of the network) usually improves the quality of the generated samples.
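To illustrate why more, shorter jumps can help when each jump is inexact, the toy sketch below replaces the trained student with a deliberately crude, hypothetical flow map: a single explicit Euler step of $\dot{x} = \cos(t)\, x$. It uses the same uniform times `delta_t * i` and `delta_t * (i + 1)` as the loop in the next cell. This is only an analogy; the trained flow-map network is not guaranteed to follow the same error pattern.

```python
import numpy as np

def velocity(x, t):
    # Toy velocity field dx/dt = cos(t) x, with exact solution x_1 = x_0 exp(sin 1).
    return np.cos(t) * x

def crude_flow_map(x, t, s):
    # Deliberately inexact "student": a single explicit Euler step from s to t.
    return x + (t - s) * velocity(x, s)

def few_shot(flow_map, x0, num_steps):
    # Uniform partition t_i = i * delta_t, as in the few-shot loop of the notebook.
    delta_t = 1.0 / num_steps
    x = x0
    for i in range(num_steps):
        x = flow_map(x, delta_t * (i + 1), delta_t * i)
    return x

x0 = np.array([1.0, -2.0, 0.5])
target = x0 * np.exp(np.sin(1.0))
errors = [np.abs(few_shot(crude_flow_map, x0, n) - target).max()
          for n in (1, 2, 4, 8, 16, 32)]
print(errors)  # shrinks as the number of steps grows
```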

In [None]:
flow_map_samples_dict = {}
cond_test = {"emb:label": batch["emb:label"][:num_samples]}

def flow_map_step(i, x, cond, delta_t):
  """Applies X^{t_{i+1}, t_i}, with t_i = i * delta_t, to the batch `x`."""
  return inference_fn(
      x,
      delta_t * (i + 1) * jnp.ones((x.shape[0],)),  # End time t_{i+1}.
      delta_t * i * jnp.ones((x.shape[0],)),        # Start time t_i.
      cond,
  )

number_of_eval_steps: tuple[int, ...] = (1, 2, 4, 8, 16, 32)

for num_steps in number_of_eval_steps:
  delta_t = 1.0 / num_steps
  # Bind the step size and conditioning under a new name, so that
  # `flow_map_step` itself is not overwritten across iterations.
  step_fn = functools.partial(flow_map_step, delta_t=delta_t, cond=cond_test)
  samples = jax.lax.fori_loop(0, num_steps, step_fn, batch["x_0"][:num_samples])
  flow_map_samples_dict[f"step_{num_steps}"] = samples


### Utility function to show the samples in a grid.

In [None]:
def plot_samples_grid(eval_samples: dict, cond: dict, num_cols: int = 12) -> None:
  """Plots the samples in the eval_samples dict by label."""
  num_rows = len(eval_samples)
  plt.figure(figsize=(num_cols*6, 5*num_rows))
  for i in range(1,num_cols+1):
    for j, (key, sample) in enumerate(eval_samples.items()):
      plt.subplot(num_rows, num_cols, i + j*num_cols)
      plt.imshow(sample[i-1, :, :, 0])
      plt.yticks(ticks=[], labels=[])
      plt.xticks(ticks=[], labels=[])
      if i == 1:
        plt.ylabel(key, fontsize=16)
      if j == 0:
        plt.title(f"label: {cond['emb:label'][i-1]}", fontsize=16)

  plt.show()

### Plotting the samples on a grid, for different labels and different numbers of applications.

In [None]:
plot_samples_grid(flow_map_samples_dict,
                  cond_test,
                  num_cols=16)