In [None]:
# Install deepinv into the environment used by this kernel (skip if already installed).
# NOTE(review): no version is pinned here; consider pinning one for reproducibility.
%pip install deepinv

<!-- MathJax macro definitions inserted automatically -->
$$
\newcommand{\forw}[1]{{A\left({#1}\right)}}
\newcommand{\noise}[1]{{N\left({#1}\right)}}
\newcommand{\inverse}[1]{{R\left({#1}\right)}}
\newcommand{\inversef}[2]{{R\left({#1},{#2}\right)}}
\newcommand{\inversename}{R}
\newcommand{\reg}[1]{{g_\sigma\left({#1}\right)}}
\newcommand{\regname}{g_\sigma}
\newcommand{\sensor}[1]{{\eta\left({#1}\right)}}
\newcommand{\datafid}[2]{{f\left({#1},{#2}\right)}}
\newcommand{\datafidname}{f}
\newcommand{\distance}[2]{{d\left({#1},{#2}\right)}}
\newcommand{\distancename}{d}
\newcommand{\denoiser}[2]{{\operatorname{D}_{{#2}}\left({#1}\right)}}
\newcommand{\denoisername}{\operatorname{D}_{\sigma}}
\newcommand{\xset}{\mathcal{X}}
\newcommand{\yset}{\mathcal{Y}}
\newcommand{\group}{\mathcal{G}}
\newcommand{\metric}[2]{{d\left({#1},{#2}\right)}}
\newcommand{\loss}[1]{{\mathcal{L}\left({#1}\right)}}
\newcommand{\conj}[1]{{\overline{#1}^{\top}}}
$$

# Flow-Matching for posterior sampling and unconditional generation

This demo shows you how to perform unconditional image generation and posterior sampling using Flow Matching (FM).

Flow matching consists in building a continuous transportation between a reference distribution $p_1$ which is easy to sample from (e.g., a Gaussian distribution) and the data distribution $p_0$.
Sampling is done by drawing $x_1 \sim p_1$ and solving, backward in time from $t=1$ to $t=0$, the following ordinary differential equation (ODE) defined by a time-dependent velocity field $v_\theta(x,t)$:

\begin{align}\frac{dx_t}{dt} = v_\theta(x_t,t), \quad x_1 \sim p_1, \quad t \in [0,1]\end{align}

The velocity field $v_\theta(x,t)$ is typically trained to approximate the conditional expectation:

\begin{align}v_\theta(x_t,t) \approx \mathbb{E}_{x_0 \sim p_0, x_1 \sim p_1}\Big[ \frac{d}{dt} x_t | x_t = a(t) x_0 + b(t) x_1 \Big]\end{align}

where $a(t)$ and $b(t)$ are interpolation coefficients such that $x_t$ interpolates between $x_0$ and $x_1$.
When the reference distribution $p_1$ is the standard Gaussian, the velocity field can be expressed as a function of a Gaussian denoiser $D(x, \sigma)$ as follows:

\begin{align}v_\theta(x_t,t) = - \frac{b'(t)}{b(t)} x_t + \frac{1}{2}\frac{a(t) b'(t) - a'(t) b(t)}{a(t) b(t)} \left(D\left(\frac{x_t}{a(t)}, \frac{b(t)}{a(t)} \right) - x_t\right)\end{align}

The most common choice of time schedulers is the linear schedule $a(t) = 1 - t$ and $b(t) = t$.

In this demo, we will show how to:

-  Perform unconditional generation using, instead of a trained denoiser, the closed-form MMSE denoiser

\begin{align}D(x, \sigma) = \mathbb{E}_{x_0 \sim p_{data}, \epsilon \sim \mathcal{N}(0, I)} \Big[ x_0 | x = x_0 + \sigma \epsilon \Big]\end{align}

Given a dataset of clean images, it can be computed by evaluating the distance between the input image and all the points of the dataset (see [`deepinv.models.MMSE`](https://deepinv.github.io/deepinv/api/stubs/deepinv.models.MMSE.html)).

-  Perform posterior sampling using Flow-Matching combined with a DPS data fidelity term (see [`sampling/demo_diffusion_sde.py`](https://deepinv.github.io/deepinv/auto_examples/sampling/demo_diffusion_sde.html#sphx-glr-auto-examples-sampling-demo-diffusion-sde-py) for more details)

-  Explore different choices of time schedulers $a(t)$ and $b(t)$.

In [None]:
import torch
import deepinv as dinv
from deepinv.sampling import (
    PosteriorDiffusion,
    DPSDataFidelity,
    EulerSolver,
    FlowMatching,
)
import numpy as np
from torchvision import datasets, transforms
from deepinv.models import MMSE

-----------------------------

We start by working with the closed-form MMSE denoiser.  It is calculated by computing the distance between the input image and all the points of the dataset.
This can take a long time for large images and large datasets.  In this toy example, we use a subset of the MNIST test set.
When using this closed-form MMSE denoiser, the sampling is guaranteed to output an image of the dataset.

In [None]:
# Device and floating-point type shared by every model in this notebook.
device = dinv.utils.get_device()
dtype = torch.float32

# Side length of one image panel in the figures below.
figsize = 2.5

# We build the closed-form MMSE denoiser using images of the MNIST test split
# (train=False) as atoms.
dataset = datasets.MNIST(
    root=".", train=False, download=True, transform=transforms.ToTensor()
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1000, shuffle=False)
n_max = (
    1000  # limit the number of images to speed up the computation of the MMSE denoiser
)

# Collect only as many batches as needed to reach n_max images, instead of
# iterating over the whole split and discarding most of it afterwards.
# shuffle=False above makes these the first n_max images of the split.
batches = []
n_collected = 0
for images, _labels in dataloader:
    batches.append(images)
    n_collected += images.shape[0]
    if n_collected >= n_max:
        break
tensors = torch.cat(batches, dim=0)[:n_max].to(device)  # (N,1,28,28) with N <= n_max

# The atoms are handed to MMSE as a single tensor through its `dataloader` argument.
denoiser = MMSE(dataloader=tensors, device=device, dtype=dtype)

---------------------------------------------------------------------

The FlowMatching module [`deepinv.sampling.FlowMatching`](https://deepinv.github.io/deepinv/api/stubs/deepinv.sampling.FlowMatching.html) uses by default the following schedules: $a_t=1-t$, $b_t=t$.
The FlowMatching module takes as input the denoiser and the ODE solver.

In [None]:
# Time grid handed to the solver: num_steps points going from t=0.99 down to t=0.
num_steps = 100
timesteps = torch.linspace(start=0.99, end=0.0, steps=num_steps)

# Seeded random generator used by the Euler solver.
rng = torch.Generator(device)
rng.manual_seed(5)
solver = EulerSolver(timesteps=timesteps, rng=rng)

# Flow-matching model driven by the MMSE denoiser; no schedule is passed,
# so the library's default a_t / b_t are used.
sde = FlowMatching(denoiser=denoiser, solver=solver, device=device, dtype=dtype)

# Generate one sample and keep the intermediate iterates.
# (x_init is given as a shape here -- presumably the shape of the initial state.)
sample, trajectory = sde(x_init=(1, 1, 28, 28), seed=0, get_trajectory=True)

# Show the generated image and save it to disk.
dinv.utils.plot(
    sample,
    titles="Unconditional FM generation",
    figsize=(figsize, figsize),
    save_fn="FM_sample.png",
)

-----------------------------------------------------------------------

Now, we can use the Flow-Matching model to perform posterior sampling.
We consider the inpainting problem, where we have a masked image and we want to recover the original image.
We use DPS [`deepinv.sampling.DPSDataFidelity`](https://deepinv.github.io/deepinv/api/stubs/deepinv.sampling.DPSDataFidelity.html) as data fidelity term (see [`sampling/demo_diffusion_sde.py`](https://deepinv.github.io/deepinv/auto_examples/sampling/demo_diffusion_sde.html#sphx-glr-auto-examples-sampling-demo-diffusion-sde-py) for more details).
Note that due to the division by $a(t)$ in the velocity field, initialization close to t=1 causes instability.

In [None]:
# Ground-truth image: the first image of the first batch, kept with its batch axis.
first_images = next(iter(dataloader))[0]
x = first_images[:1].to(device)

# Inpainting mask: ones everywhere except a zeroed square (rows/cols 10..19).
mask = torch.ones_like(x)
mask[..., 10:20, 10:20] = 0.0

# Forward operator: masking followed by additive Gaussian noise (sigma=0.1).
noise_model = dinv.physics.GaussianNoise(sigma=0.1)
physics = dinv.physics.Inpainting(
    img_size=x.shape[1:],
    mask=mask,
    device=device,
    noise_model=noise_model,
)
y = physics(x)

# DPS data-fidelity term built on the same MMSE denoiser.
dps_fidelity = DPSDataFidelity(denoiser=denoiser, weight=1.0)

# Posterior sampler combining the flow-matching model with the DPS term.
model = PosteriorDiffusion(
    data_fidelity=dps_fidelity,
    sde=sde,
    solver=solver,
    dtype=dtype,
    device=device,
    verbose=True,
)

# Draw one posterior sample given the measurement y and keep the iterates.
x_hat, trajectory = model(y, physics, x_init=None, get_trajectory=True, seed=0)

# Here, we plot the original image, the measurement and the posterior sample
dinv.utils.plot(
    [x, y, x_hat],
    titles=["Original", "Measurement", "Posterior sample"],
    figsize=(figsize * 3, figsize),
    save_fn="FM_posterior.png",
    show=True,
)

----------------------------------------------------------------

Finally, we show how to use different choices of time schedulers $a_t$ and $b_t$.
Here, we use another typical choice of schedulers $a_t = \cos(\frac{\pi}{2} t)$ and $b_t = \sin(\frac{\pi}{2} t)$ which also satisfy the interpolation condition $a_0 = 1$, $b_0 = 0$, $a_1 = 0$, $b_1 = 1$.
Note that, again, due to the division by $a_t$ in the velocity field, initialization close to t=1 causes instability.

In [None]:
def a_t(t):
    """Coefficient of x_0 in the interpolation: a(t) = cos(pi/2 * t), for a tensor t."""
    return torch.cos(np.pi / 2 * t)


def a_prime_t(t):
    """Time derivative of a(t): a'(t) = -pi/2 * sin(pi/2 * t)."""
    return -np.pi / 2 * torch.sin(np.pi / 2 * t)


def b_t(t):
    """Coefficient of x_1 in the interpolation: b(t) = sin(pi/2 * t), for a tensor t."""
    return torch.sin(np.pi / 2 * t)


def b_prime_t(t):
    """Time derivative of b(t): b'(t) = pi/2 * cos(pi/2 * t)."""
    return np.pi / 2 * torch.cos(np.pi / 2 * t)


# Flow-matching model with the trigonometric schedule instead of the default one.
sde = FlowMatching(
    a_t=a_t,
    a_prime_t=a_prime_t,
    b_t=b_t,
    b_prime_t=b_prime_t,
    denoiser=denoiser,
    solver=solver,
    device=device,
    dtype=dtype,
)

# Posterior sampler reusing the same DPS data-fidelity term and solver.
model = PosteriorDiffusion(
    data_fidelity=dps_fidelity,
    sde=sde,
    solver=solver,
    dtype=dtype,
    device=device,
    verbose=True,
)

# Pass an explicit seed, as in the posterior-sampling call with the default
# schedule, so the result does not depend on random-generator state left over
# from earlier cells.
x_hat, trajectory = model(
    y,
    physics,
    x_init=None,
    get_trajectory=True,
    seed=0,
)

# Here, we plot the original image, the measurement and the posterior sample
dinv.utils.plot(
    [x, y, x_hat],
    show=True,
    titles=["Original", "Measurement", "Posterior sample"],
    figsize=(figsize * 3, figsize),
    save_fn="FM_posterior_new_at_bt.png",
)