# Score-Based Generative Modeling

This repository is an implementation of [Song et al. (2020)](https://arxiv.org/pdf/2011.13456), and we try to reference equation numbers where possible. Code is built using JAX and [`equinox`](https://docs.kidger.site/equinox/). The UNet architecture is adapted from the [`sbgm` package](https://github.com/homerjed/sbgm), and we also adapt some of the code therein for our implementation, as well as code from [Patrick Kidger's tutorial](https://docs.kidger.site/equinox/examples/score_based_diffusion/).

### Intro

Score-based generative modeling (SBGM) can be seen as a continuous-time reformulation and generalization of the original denoising diffusion probabilistic model (DDPM) paper. ([For an example implementation and tutorial of DDPM, see here](https://github.com/declanmcnamara/ddpm_mnist).)

In SBGM, data are "noised up" according to an SDE. For suitably chosen SDEs, it can be shown that as the SDE evolves in time, the distribution of the noised data tends towards a standard Gaussian, regardless of the original data point chosen.

Generative modeling is performed by reversing this process through Anderson's theorem. We explain in more detail below.

#### Example: The Ornstein-Uhlenbeck SDE

Suppose we have
    \begin{equation*}
        d X_t  = m X_t dt + dW_t \ \ \ \textrm{[Ornstein-Uhlenbeck process]}
    \end{equation*}
with $m=-1$, where $W_t$ denotes a standard Brownian motion. For initial condition $X_0$, the solution is
\begin{equation*}
    X_t = X_0 e^{mt} + \int_0^t e^{m(t-s)} dW_s = X_0 e^{-t} + \int_0^t e^{-(t-s)} dW_s.
\end{equation*}

See Example 6.8 of Karatzas and Shreve, *Brownian Motion and Stochastic Calculus*, for additional details.

Substituting $m=-1$ gives
\begin{equation*}
        d X_t  = -X_t dt + dW_t \ \ \ \textrm{[Ornstein-Uhlenbeck process]}.
\end{equation*}
The stochastic integral above is Gaussian with mean zero and variance $\int_0^t e^{-2(t-s)} ds = \frac{1}{2}(1-e^{-2t})$, so
\begin{align*}
    X_t \mid X_0 &\sim \mathcal{N}\left(X_0 e^{-t}, \tfrac{1}{2}(1-e^{-2t})\right)  \\
    \implies X_t \mid X_0 &\overset{d}{\to} \mathcal{N}(0, 1/2)
\end{align*}
as $t \to \infty$.

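This is just a 1-dimensional example, but it illustrates convergence to a chosen reference distribution. Here the reference is $\mathcal{N}(0, 1/2)$ rather than the standard Gaussian, but this can be fixed by rescaling: the same calculation for $dX_t = -X_t dt + \sigma dW_t$ gives the limit $\mathcal{N}(0, \sigma^2/2)$, so $\sigma = \sqrt{2}$ yields $\mathcal{N}(0, 1)$.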
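As a quick numerical sanity check of the closed form, the sketch below simulates the Ornstein-Uhlenbeck SDE with $m=-1$ by Euler-Maruyama and compares the empirical mean and variance at $t=1$ against $X_0 e^{-t}$ and $\frac{1}{2}(1-e^{-2t})$. The starting point $X_0 = 2$, the step size, and the number of paths are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# Euler-Maruyama for dX_t = m X_t dt + dW_t with m = -1, all paths started at X_0 = 2.
m, x0, t_end, n_steps, n_paths = -1.0, 2.0, 1.0, 1_000, 20_000
dt = t_end / n_steps


def em_step(x, key):
    # x_{k+1} = x_k + m x_k dt + sqrt(dt) * z_k,  z_k ~ N(0, 1)
    x = x + m * x * dt + jnp.sqrt(dt) * jr.normal(key, x.shape)
    return x, None


keys = jr.split(jr.PRNGKey(0), n_steps)
x_t, _ = jax.lax.scan(em_step, jnp.full((n_paths,), x0), keys)

# Closed form: X_t | X_0 ~ N(X_0 e^{-t}, (1 - e^{-2t}) / 2).
mean_exact = x0 * jnp.exp(-t_end)
var_exact = 0.5 * (1.0 - jnp.exp(-2.0 * t_end))
print("mean:", x_t.mean(), "vs", mean_exact)
print("variance:", x_t.var(), "vs", var_exact)
```

The two agree up to Monte Carlo and discretization error, which is why sampling from the closed-form transition is preferable to simulating the trajectory when it is available.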

#### Anderson's Theorem

Let $X_0 \sim p_0$ and thereafter evolve according to the SDE
\begin{equation*}
    dX_t = f(X_t, t) dt + g(t) dW_t.
\end{equation*}

Fix $T > 0$ and let $p_t$ denote the density of $X_t$ as defined above. Then with $U_0 \sim p_T$ and
\begin{equation*}
    dU_t = -\left[ f(U_t, T-t) - g^2(T-t) \nabla_x \log p_{T-t}(U_t) \right] dt + g(T-t) d\tilde{W}_t,
\end{equation*}
we have $X_t \overset{d}{=} U_{T-t}$.
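A minimal numerical check of Anderson's theorem is possible in a toy setting (our assumption, chosen so everything is known in closed form): Gaussian data $X_0 \sim \mathcal{N}(0, 4)$ and the rescaled Ornstein-Uhlenbeck SDE $dX_t = -X_t dt + \sqrt{2} dW_t$, i.e. $f(x, t) = -x$ and $g(t) = \sqrt{2}$. Then $p_t = \mathcal{N}(0, 4e^{-2t} + 1 - e^{-2t})$ and the score is exact, so integrating the reverse SDE with Euler-Maruyama from $U_0 \sim p_T$ should approximately reproduce $p_0$.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# Assumed toy setting: f(x, t) = -x, g(t) = sqrt(2), data X_0 ~ N(0, s0_sq).
s0_sq, T, n_steps, n_samples = 4.0, 5.0, 500, 20_000
dt = T / n_steps


def var_t(t):
    # Var(X_t) = s0_sq e^{-2t} + (1 - e^{-2t}); p_t stays Gaussian with mean 0.
    return s0_sq * jnp.exp(-2.0 * t) + 1.0 - jnp.exp(-2.0 * t)


def true_score(x, t):
    # grad_x log N(x; 0, var_t(t))
    return -x / var_t(t)


def reverse_step(u, inputs):
    s, key = inputs
    t = T - s  # forward time matching reverse time s
    drift = -(-u - 2.0 * true_score(u, t))  # -[f(u, T-s) - g^2 * score]
    u = u + drift * dt + jnp.sqrt(2.0 * dt) * jr.normal(key, u.shape)
    return u, None


key0, key_path = jr.split(jr.PRNGKey(0))
u0 = jnp.sqrt(var_t(T)) * jr.normal(key0, (n_samples,))  # U_0 ~ p_T, known exactly here
s_grid = jnp.arange(n_steps) * dt
u_final, _ = jax.lax.scan(reverse_step, u0, (s_grid, jr.split(key_path, n_steps)))
print("mean:", u_final.mean(), "variance:", u_final.var(), "target variance:", s0_sq)
```

Here the score is exact and $p_T$ is known, so the remaining error comes only from time discretization and Monte Carlo noise; in SBGM both are approximations.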

    

Anderson's theorem allows us to perform generative modeling provided we can do two things well:

1. Sampling $U_0 \sim p_T$
2. Computing the score function $\nabla_x \log p_t(x)$ for any $t,x$.

Neither can be done exactly. For #1, we only have $p_T \approx \mathcal{N}(0, I)$, and we sample $U_0$ from this standard Gaussian as if the equality held exactly; this is one source of approximation error. For #2, the score of the marginal distributions, $\nabla_x \log p_t(x)$, is generally unknown and must be approximated. Note that the conditional score $\nabla_{x_t} \log p_{t \mid x_0}(x_t \mid x_0)$ may be easy to compute (see the Ornstein-Uhlenbeck example above), but Anderson's theorem requires the unconditional (marginal) score.
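For the Ornstein-Uhlenbeck example, the conditional density is Gaussian, so differentiating its log-density gives the conditional score in closed form:

$$
\nabla_{x_t} \log p_{t \mid x_0}(x_t \mid x_0) = -\frac{x_t - x_0 e^{-t}}{\frac{1}{2}\left(1 - e^{-2t}\right)} = -\frac{2\left(x_t - x_0 e^{-t}\right)}{1 - e^{-2t}}.
$$

The marginal score has no such formula in general, because $p_t(x_t) = \int p_{t \mid x_0}(x_t \mid x_0)\, p_0(x_0)\, dx_0$ involves the unknown data distribution $p_0$.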

### Method

SBGM (Song et al., 2020) fits a neural network to approximate the unconditional score. The network function $s_\theta$ takes two arguments, $x$ and $t$, and returns an approximation to $\nabla_x \log p_t(x)$.

The objective function is (Eq. 7 of Song et al.)

$$
\mathbb{E}_t \left[\lambda(t) \mathbb{E}_{x_0}\mathbb{E}_{x_t \mid x_0} \left(||s_\theta(x_t, t) - \nabla_{x_t} \log p(x_t \mid x_0)||_2^2 \right)  \right],
$$

where

- $\mathbb{E}_t$ is an expectation over $t \sim \mathrm{Unif}[0,T]$.
- $\lambda(t) > 0$ is a weighting function. Since $\lambda(t)$ can be absorbed into the density of $t$, a non-constant weighting effectively says that $\mathrm{Unif}[0,T]$ is not necessarily the ideal distribution over $t$ to average over; choosing it well is beyond our scope.
- The expectation $\mathbb{E}_{x_0}\mathbb{E}_{x_t \mid x_0}$ is approximated by Monte Carlo ancestral sampling: draw $x_0 \sim p_{\textrm{data}}$, then $x_t \mid x_0$ according to the forward (noising) SDE.
- The conditional score $\nabla_{x_t} \log p(x_t \mid x_0)$ is computed analytically.

Note that $s_\theta(x_t, t)$ does not receive knowledge about $x_0$; it must be fit to learn the unconditional score.
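To see why fitting to the conditional score still recovers the unconditional one, consider a toy case (our assumption, not part of the original experiments): one-dimensional data $x_0 \sim \mathcal{N}(0, 4)$ and a single VP time with $\int_0^t \beta(s)\,ds = 1$, so that $x_t \mid x_0 \sim \mathcal{N}(e^{-1/2} x_0, 1 - e^{-1})$ and the marginal is $\mathcal{N}(0, 4e^{-1} + 1 - e^{-1})$. A one-parameter model $s_w(x) = -w x$ that never sees $x_0$ is fit to the conditional scores by gradient descent; $1/w$ should land near the marginal variance.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# Toy assumptions: x0 ~ N(0, s0_sq) in 1-D, and one VP time with int_0^t beta = 1.
s0_sq, beta_int, n = 4.0, 1.0, 100_000
alpha = jnp.exp(-0.5 * beta_int)            # E[x_t | x_0] = alpha * x_0
sigma = jnp.sqrt(1.0 - jnp.exp(-beta_int))  # sd of x_t | x_0

key_x0, key_eps = jr.split(jr.PRNGKey(0))
x0 = jnp.sqrt(s0_sq) * jr.normal(key_x0, (n,))
xt = alpha * x0 + sigma * jr.normal(key_eps, (n,))  # ancestral sampling
cond_score = -(xt - alpha * x0) / sigma**2          # grad log p(x_t | x_0)


def loss(w):
    # The model s_w(x) = -w * x sees only x_t, never x_0.
    return jnp.mean((-w * xt - cond_score) ** 2)


grad_loss = jax.jit(jax.grad(loss))
w = 1.0
for _ in range(300):
    w = w - 0.1 * grad_loss(w)

marginal_var = alpha**2 * s0_sq + sigma**2
print("fitted 1/w:", 1.0 / w, "marginal variance:", marginal_var)
```

Each conditional target is noisy, but the squared-error minimizer is its conditional expectation given $x_t$, which equals the marginal score $\nabla_{x_t} \log p_t(x_t)$.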


### Algorithm

Items that must be fixed a priori:

1. The choice of the forward (noising) SDE, with a known limiting reference distribution.
2. Weighting function $\lambda(t)$.
3. Network architecture $s_\theta(x, t)$ that takes in data points and time values.


Thereafter, the algorithm can proceed as:

1. Draw a data point $x_0$.
2. Draw a time $t \sim \mathrm{Unif}[0,T]$. 
3. Sample $x_t \mid x_0$ from the forward SDE, ideally using its closed-form transition distribution (as in the Ornstein-Uhlenbeck example above) rather than simulating the trajectory with an Euler-Maruyama discretization.
4. Compute the conditional score $\nabla_{x_t} \log p(x_t \mid x_0)$ at the draw.
5. Compute the predicted score $s_\theta(x_t, t)$.
6. Compute the gradient of $\lambda(t) \cdot ||s_\theta(x_t, t) - \nabla_{x_t} \log p(x_t \mid x_0)||_2^2$ with respect to $\theta$.
7. Update $\theta$ along the negative gradient direction. 

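Of course, the algorithm above can be performed across batches of data points. Note the two types of gradients floating around: the conditional score $\nabla_{x_t} \log p(x_t \mid x_0)$ is a gradient with respect to the data point $x_t$ and appears inside the objective, whereas training takes a second gradient of that objective with respect to $\theta$, the network parameters.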
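The seven steps can be sketched end to end on a toy problem. Everything here is an assumption made for illustration: 1-D data $x_0 \sim \mathcal{N}(0, 4)$, a VP SDE with constant $\beta(t) = 1$ (so $\int_0^t \beta(s)\, ds = t$), $T = 5$, the weighting $\lambda(t) = 1 - e^{-t}$, and a hypothetical two-parameter linear score model $s_\theta(x, t) = (\theta_0 + \theta_1 t)\, x$ in place of a neural network. Times are drawn from $[10^{-3}, T]$ to keep the conditional variance away from zero.

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# Toy assumptions: 1-D data x0 ~ N(0, 4); VP SDE with beta(t) = 1; T = 5;
# lambda(t) = 1 - e^{-t}; hypothetical linear model s_theta(x, t) = (theta0 + theta1 t) x.
T, batch, lr = 5.0, 512, 1e-2


def score_model(theta, x, t):
    return (theta[0] + theta[1] * t) * x


def loss_fn(theta, key):
    k_x0, k_t, k_eps = jr.split(key, 3)
    x0 = 2.0 * jr.normal(k_x0, (batch,))                    # 1. draw data
    t = jr.uniform(k_t, (batch,), minval=1e-3, maxval=T)    # 2. draw times
    alpha, var = jnp.exp(-0.5 * t), 1.0 - jnp.exp(-t)
    xt = alpha * x0 + jnp.sqrt(var) * jr.normal(k_eps, (batch,))  # 3. closed-form x_t | x_0
    cond_score = -(xt - alpha * x0) / var                   # 4. conditional score
    pred = score_model(theta, xt, t)                        # 5. predicted score
    return jnp.mean(var * (pred - cond_score) ** 2)         # weighted loss, lambda(t) = var


@jax.jit
def train_step(theta, key):
    loss, grad = jax.value_and_grad(loss_fn)(theta, key)    # 6. gradient w.r.t. theta
    return theta - lr * grad, loss                          # 7. step along -gradient


eval_key, key = jr.split(jr.PRNGKey(0))
theta = jnp.zeros(2)
init_loss = loss_fn(theta, eval_key)
for _ in range(200):
    key, step_key = jr.split(key)
    theta, _ = train_step(theta, step_key)
final_loss = loss_fn(theta, eval_key)
print("loss on a fixed batch before/after:", init_loss, final_loss)
```

The conditional score in step 4 is written analytically here; the implementation below obtains it by automatic differentiation of a log-density instead.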

### Implementation - Forward SDE + Training

First, we define a `ForwardSDE` module. 

We try to use notation that mirrors that of Song et al. (2020), and reference equations where possible. The forward SDE we implement is the variance preserving (VP) SDE (Eq. 25 of Song et al.). Wherever $\beta$ appears in code, it refers to this notation. The SDE is

$$
dx_t = -\frac{1}{2} \beta(t) x_t dt + \sqrt{\beta(t)} dW_t
$$

Provided $\int_0^t \beta(s)\, ds \to \infty$, this SDE tends toward the standard Gaussian as $t$ grows large (Eq. 29). Below, `beta_module` specifies the function $\beta(t)$ that defines the forward process, and allows efficient evaluation of both $\beta(t)$ and its integral $\int_0^t \beta(s)\, ds$.
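Because the transition distribution of the VP SDE depends on $\beta$ only through its integral,

$$
x_t \mid x_0 \sim \mathcal{N}\left(x_0 e^{-\frac{1}{2}\int_0^t \beta(s)\,ds},\ \left(1 - e^{-\int_0^t \beta(s)\,ds}\right) I\right),
$$

one convenient design is to specify the schedule through its integral and recover $\beta(t)$ by differentiation. The sketch below assumes that design. `IntegralBetaSchedule` is a hypothetical stand-in written for illustration; it is not the repository's `BetaIntegralDefinedScheduler`, whose definition is not shown here.

```python
from typing import Callable

import jax
import jax.numpy as jnp


class IntegralBetaSchedule:
    """Hypothetical schedule specified by B(t) = int_0^t beta(s) ds."""

    def __init__(self, beta_int_fn: Callable):
        self.beta_int = beta_int_fn
        # beta(t) = dB/dt, obtained by autodiff of the integral.
        self.beta = jax.grad(beta_int_fn)

    def mean_coef(self, t):
        # E[x_t | x_0] = mean_coef(t) * x_0
        return jnp.exp(-0.5 * self.beta_int(t))

    def var(self, t):
        # Var[x_t | x_0] = 1 - exp(-B(t)) in every coordinate
        return 1.0 - jnp.exp(-self.beta_int(t))


# Constant beta(t) = 1; its integral is the `lambda t: t` passed as `beta_int_fcn` in the training cell.
const = IntegralBetaSchedule(lambda t: t)

# Linear beta(t) = b_min + (b_max - b_min) t on [0, 1]; values chosen for illustration.
b_min, b_max = 0.1, 20.0
linear = IntegralBetaSchedule(lambda t: b_min * t + 0.5 * (b_max - b_min) * t**2)

print(const.beta(3.0), const.mean_coef(3.0), const.var(3.0))
print(linear.beta(0.5))
```

Defining only the integral keeps $\beta(t)$ and $\int_0^t \beta(s)\,ds$ consistent by construction, at the cost of one autodiff call per evaluation of $\beta(t)$.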

In [None]:
import math
from abc import ABC, abstractmethod
from functools import partial
from typing import Callable, Optional, Self, Sequence, Tuple, Union

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
from jaxtyping import Array, Key

class ForwardSDE(eqx.Module):
    betas: Array
    beta_module: eqx.Module
    dt: float

    def __init__(self, betas: Array, dt: float = 0.01):
        """
        Construct an SDE.
        """
        super().__init__()
        self.dt = dt
        self.betas = betas
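        # `BetaIntegralDefinedScheduler` is not defined in this notebook; it must
        # provide `beta(t)` and `beta_int(t)` (the integral of beta), used below.
        self.beta_module = (
            BetaIntegralDefinedScheduler()
        )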

    def forward_dist(self, t, x0):
        """Return the distribution of x_t conditional on x_0.

        Form of Eq. (29) of SBGM.
        """
        beta_int = self.beta_module.beta_int(t)
        mu = x0 * jnp.exp((-1 / 2) * beta_int)
        sig_sq = 1.0 - jnp.exp((-1 * beta_int))
        scale = jnp.sqrt(sig_sq)
        xt_dist = dist.Normal(mu, scale)
        return xt_dist

    def forward_sample(self, t, x0, key):
        xt_dist = self.forward_dist(t, x0)
        return xt_dist.sample(key)

    def forward_sample_rparam(self, t, x0, key):
        beta_int = self.beta_module.beta_int(t)
        mu = x0 * jnp.exp((-1 / 2) * beta_int)
        sig_sq = 1.0 - jnp.exp((-1 * beta_int))
        scale = jnp.sqrt(sig_sq)
        epsilon = jr.normal(key, x0.shape)
        return mu, scale, epsilon

    def marginal_log_prob(self, xt, t, x0):
        xt_dist = self.forward_dist(t, x0)
        return xt_dist.log_prob(xt).sum()

    def f(self, x, t):
        """Define quantity $f(x, t)$ in Eq. (15) of SBGM."""
        return (-1 / 2) * self.beta_module.beta(t) * x

    def G(self, x, t):
        """Define quantity $G(x,t)$ in Eq. (15) of SBGM.

        NOTE: We do not explicitly return G(x,t) as a matrix.
        This allows us to simply keep images in their original shape for
        ease rather than dealing with reshaping, matrix multiplication, etc.
        """
        # jnp.sqrt (unlike math.sqrt) also works on traced arrays under jit/vmap.
        value = jnp.sqrt(self.beta_module.beta(t))
        return value


We provide equation number references where useful in the comments. The key functions to note are:

- `forward_dist`, returns the distribution of $x_t$ given $x_0$ and a time $t$.
- `marginal_log_prob`, returns $\log p_t(x_t \mid x_0)$ given all of these quantities, summed over all pixels.

Below, `marginal_log_prob` will be differentiated with respect to $x_t$ to obtain the conditional score.
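Because $x_t \mid x_0$ has covariance proportional to the identity, its log-density is the sum of independent per-pixel Gaussian log-densities, which is why `marginal_log_prob` calls `.sum()`. A quick check of that identity, with a toy $1 \times 4 \times 4$ "image" and an arbitrary value $\int_0^t \beta(s)\,ds = 0.7$, using `jax.scipy.stats` in place of numpyro:

```python
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import multivariate_normal, norm

# Toy 1 x 4 x 4 "image" and an assumed value B = int_0^t beta(s) ds = 0.7.
key_x0, key_eps = jr.split(jr.PRNGKey(0))
x0 = jr.normal(key_x0, (1, 4, 4))
beta_int = 0.7
mu = x0 * jnp.exp(-0.5 * beta_int)
scale = jnp.sqrt(1.0 - jnp.exp(-beta_int))
xt = mu + scale * jr.normal(key_eps, x0.shape)

# Per-pixel Gaussian log-densities summed over pixels, as in marginal_log_prob.
summed = norm.logpdf(xt, loc=mu, scale=scale).sum()

# The same density as one Gaussian with covariance scale^2 * I on the flattened image.
d = xt.size
joint = multivariate_normal.logpdf(xt.ravel(), mu.ravel(), scale**2 * jnp.eye(d))
print(summed, joint)
```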

In [None]:
@eqx.filter_jit
def score(xt, t, x0, forward_sde: ForwardSDE):
    wrapped_grad_fn = eqx.filter_grad(forward_sde.marginal_log_prob)
    return wrapped_grad_fn(xt, t, x0)

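The `eqx.filter_*` transformations conveniently handle non-array arguments: `eqx.filter_jit` treats them as static during JIT compilation, and `eqx.filter_grad` differentiates only with respect to the floating-point arrays in its first argument. The `score` function thus amounts to `jax.grad` applied to the `marginal_log_prob` function above with respect to `xt`, giving us $\nabla_{x_t} \log p_t(x_t \mid x_0)$ as desired.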
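Since $x_t \mid x_0$ is Gaussian with mean $\mu$ and variance $\sigma^2$, the gradient computed by `score` should equal $-(x_t - \mu)/\sigma^2$. The sketch below checks this with `jax.grad` on a `jax.scipy.stats` log-density standing in for `marginal_log_prob` (the toy shape and $\int_0^t\beta(s)\,ds = 0.7$ are assumptions).

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import norm

# Toy 1 x 4 x 4 "image" and an assumed value B = int_0^t beta(s) ds = 0.7.
key_x0, key_eps = jr.split(jr.PRNGKey(1))
x0 = jr.normal(key_x0, (1, 4, 4))
mu = x0 * jnp.exp(-0.5 * 0.7)
sig_sq = 1.0 - jnp.exp(-0.7)
xt = mu + jnp.sqrt(sig_sq) * jr.normal(key_eps, x0.shape)


def log_prob(xt):
    # Scalar log p(x_t | x_0), summed over pixels like marginal_log_prob.
    return norm.logpdf(xt, loc=mu, scale=jnp.sqrt(sig_sq)).sum()


autodiff_score = jax.grad(log_prob)(xt)
analytic_score = -(xt - mu) / sig_sq
print(jnp.max(jnp.abs(autodiff_score - analytic_score)))
```

The closed form could replace the autodiff call entirely; with $x_t = \mu + \sigma\epsilon$ it simplifies to $-\epsilon/\sigma$.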

In [None]:
@eqx.filter_jit
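def sample_time(key, t0: float, t1: float, n_sample: int):
    # Stratified sampling: one uniform draw in each of n_sample equal-width
    # sub-intervals of [t0, t1), which lowers the variance of the loss estimate.
    width = (t1 - t0) / n_sample
    t = jr.uniform(key, (n_sample,), minval=0.0, maxval=width)
    t = t0 + t + width * jnp.arange(n_sample)
    return t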


@eqx.filter_jit
def my_single_loss_fn(score_model, t, x0, forward_sde: ForwardSDE, key):
    # Separate keys for the forward-process noise and the model's dropout.
    noise_key, model_key = jr.split(key)
    mu, scale, eps = forward_sde.forward_sample_rparam(t, x0, noise_key)
    xt_draw = mu + scale * eps
    pred_score = score_model(t, xt_draw, key=model_key)
    actual_score = score(xt_draw, t, x0, forward_sde)
    # lambda(t) = 1 - exp(-int_0^t beta), the variance of x_t | x_0.
    weight = lambda t: 1 - jnp.exp(-forward_sde.beta_module.beta_int(t))
    return weight(t) * jnp.square(pred_score - actual_score).sum()


@eqx.filter_jit
def my_batched_loss_fn(score_model, batch_x0, forward_sde, key):
    time_key, sde_key = jr.split(key)
    batch_size = batch_x0.shape[0]
    t = sample_time(time_key, 0.0, 10.0, batch_size)
    # One key per example, so each image gets independent noise.
    sde_keys = jr.split(sde_key, batch_size)
    part_x0_t = partial(
        my_single_loss_fn, score_model=score_model, forward_sde=forward_sde
    )
    return jax.vmap(part_x0_t)(x0=batch_x0, t=t, key=sde_keys).mean()

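These functions are standard; some of the code is adapted from the tutorials referenced above. The weighting function $\lambda(t) = 1 - \exp\left(-\int_0^t \beta(s)\, ds\right)$ is the variance of $x_t \mid x_0$. Song et al. suggest $\lambda(t) \propto 1/\mathbb{E}\left[||\nabla_{x_t} \log p(x_t \mid x_0)||_2^2\right]$, and for a Gaussian transition that expectation is inversely proportional to the variance, so the two choices agree. The weighting depends on attributes of the forward SDE, but can be altered.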

The function `my_single_loss_fn` implements the algorithm outlined above. 

The function `my_batched_loss_fn` wraps the above to support batching via `vmap`, drawing stratified times with `sample_time`. For this problem we set $[0,T] = [0,10]$; we had to experiment to find a $T$ large enough that $x_T$ is approximately standard Gaussian.
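`sample_time` uses stratified sampling: it splits the time range into `n_sample` equal bins and draws one uniform time per bin, so every batch covers the whole range. A standalone version (the name `sample_time_stratified` is ours) that can be checked directly:

```python
import jax.numpy as jnp
import jax.random as jr


def sample_time_stratified(key, t0, t1, n_sample):
    # Split [t0, t1) into n_sample equal bins and draw one uniform time per bin.
    width = (t1 - t0) / n_sample
    offsets = jr.uniform(key, (n_sample,), minval=0.0, maxval=width)
    return t0 + width * jnp.arange(n_sample) + offsets


t = sample_time_stratified(jr.PRNGKey(0), 0.0, 10.0, 8)
print(t)  # by construction, t[i] falls in [1.25 i, 1.25 (i + 1))
```

The batch average over these times is still an unbiased estimate of the expectation over $\mathrm{Unif}[t_0, t_1)$, since each bin contributes exactly its share, while its variance is typically lower than with i.i.d. draws.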

In [None]:
@eqx.filter_jit
def make_step(score_model, batch_x0, forward_sde, optimizer, opt_state, key):
    time_key, loss_key = jr.split(key)
    score_model = eqx.nn.inference_mode(score_model, False)

    loss_value, grads = eqx.filter_value_and_grad(my_batched_loss_fn)(
        score_model, batch_x0, forward_sde, loss_key
    )
    updates, opt_state = optimizer.update(
        grads, opt_state, eqx.filter(score_model, eqx.is_array)
    )
    score_model = eqx.apply_updates(score_model, updates)
    key = jr.split(time_key, 1)[0]  # new key
    return score_model, opt_state, loss_value, key

Lastly, `make_step` performs an entire step of the training loop. Notice we now take a second gradient, that of the batched loss function, with respect to the parameters of the `score_model`, and take a gradient step.

### Implementation - Reverse SDE + Sampling

The `ReverseSDE` class draws samples by Euler-Maruyama integration of the reverse SDE, with the trained score network standing in for the score function in the drift term.
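Concretely, each iteration of `sample` and `sample_K` applies the reverse-time Euler-Maruyama update with step size $\Delta t$ (`self.dt`):

$$
x_{t - \Delta t} = x_t - \left[ f(x_t, t) - g(t)^2\, s_\theta(x_t, t) \right] \Delta t + g(t) \sqrt{\Delta t}\, z_t, \qquad z_t \sim \mathcal{N}(0, I),
$$

starting from $x_T \sim \mathcal{N}(0, I)$ and stepping $t$ down from $T = 10$ towards zero. For the VP SDE, $f(x, t) = -\frac{1}{2}\beta(t) x$ and $g(t) = \sqrt{\beta(t)}$. Both methods return the final mean, i.e. they skip the noise term on the last step.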

In [None]:
class ReverseSDE(eqx.Module):
    dt: float
    score_model: eqx.Module
    forward_sde: eqx.Module
    base_dist: numpyro.distributions.Distribution
    shape: tuple
    epsilon: float

    def __init__(self, dt, score_model, forward_sde):
        """
        Construct an SDE.
        """
        super().__init__()
        self.dt = dt
        self.score_model = score_model
        self.forward_sde = forward_sde
        self.shape = (1, 28, 28)
        self.base_dist = dist.Normal(loc=jnp.zeros(self.shape))
        self.epsilon = 1e-5

    def tilde_f(self, x, t):
        """The reverse SDE drift coefficient.

        Defined in terms of x, t as well as
        f, G from the forward SDE, as well as the
        the score function (replaced by a trained model
        of the score function here.)
        """
        fxt = self.forward_sde.f(x, t)
        gxt = self.forward_sde.G(x, t)  # scalar
        score = self.score_model(t, x)
        return fxt - (gxt**2) * score

    def tilde_G(self, x, t):
        """Reverse-time diffusion coefficient; identical to the forward g(t)."""
        return self.forward_sde.G(x, t)

    def sample(self, key):
        key, init_key = jr.split(key)
        x1 = self.base_dist.sample(key=init_key)
        # Reverse time grid from T = 10 (matching training) down to ~0.
        time_grid = jnp.arange(
            start=10.0, stop=self.epsilon - self.dt, step=-1 * self.dt
        )

        curr_x = x1
        next_x_mean = x1
        for time in time_grid:
            # Euler-Maruyama step in reverse time:
            #   dx <- -[f(x, t) - g^2(t) * score(x, t)] * dt + g(t) * sqrt(dt) * eps_t
            #   x  <- x + dx
            #   t  <- t - dt
            key, step_key = jr.split(key)
            time = jnp.array(time)

            eps_t = jr.normal(step_key, self.shape)
            drift = self.tilde_f(curr_x, time)
            diffusion = self.tilde_G(curr_x, time)
            next_x_mean = curr_x - drift * self.dt  # mu_x = x + drift * -step

            curr_x = next_x_mean + diffusion * jnp.sqrt(self.dt) * eps_t

        # Return the final mean, i.e. skip the noise on the last step.
        return next_x_mean

    def sample_K(self, K, key):
        key, init_key = jr.split(key)
        x1s = self.base_dist.sample(init_key, sample_shape=(K,))
        time_grid = jnp.arange(start=10.0, stop=self.epsilon, step=-1 * self.dt)
        curr_xs = x1s
        next_xs_means = x1s
        for time in time_grid:
            key, step_key = jr.split(key)
            time = jnp.array(time)

            # The drift depends on x (via the score network), so vmap it over the batch.
            batch_drift = jax.vmap(partial(self.tilde_f, t=time))
            # The diffusion coefficient g(t) does not depend on x, so one scalar suffices.
            diffusion = self.tilde_G(curr_xs, time)

            eps_t = jr.normal(step_key, x1s.shape)
            drifts = batch_drift(curr_xs)
            next_xs_means = curr_xs - drifts * self.dt  # mu_x = x + drift * -step

            curr_xs = next_xs_means + diffusion * jnp.sqrt(self.dt) * eps_t

        return next_xs_means

    def plot_grid(self, nrow, ncol, key):
        K = nrow * ncol
        out = self.sample_K(K=K, key=key)

        fig, ax = plt.subplots(nrow, ncol)
        for i in range(nrow):
            for j in range(ncol):
                entry = ncol * i + j
                this_image = out[entry][0]
                ax[i, j].imshow(this_image)
        plt.savefig("example_fig.png")
        plt.clf()


### Full Training Loop

We wrap everything into a training loop. Periodically, we use the reverse SDE to sample digits from noise to see how well we're doing.

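The code below sets up the experiment: it loads the config, seeds the random keys, constructs a dataloader from the data, and instantiates the score model and optimizer. Note that it imports the repository's `ForwardSDE` and `ReverseSDE` from the `samplers` module, which take the place of the definitions above.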

In [None]:
from hydra import compose, initialize
import os
import hydra
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax
import torch
import tqdm
import sys
sys.path.append("../")
from network import UNet
from samplers import ForwardSDE, ReverseSDE
from utils import (
    MNISTDataLoader,
    load_model,
    load_opt_state,
    save_model,
    save_opt_state,
)

with initialize(version_base=None, config_path="../conf"):
    cfg = compose(config_name="mnist")
    
seed = cfg.seed

if torch.cuda.is_available():
    device = cfg.training.device
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device)
else:
    device = "cpu"

key = jr.key(seed)
model_key, loader_key, train_key, sample_key = jr.split(key, 4)
train_loader = MNISTDataLoader(key=loader_key, batch_size=cfg.training.batch_size)

score_model = UNet(
    data_shape=(1, 28, 28),
    is_biggan=cfg.model.is_biggan,
    dim_mults=cfg.model.dim_mults,
    hidden_size=cfg.model.hidden_size,
    heads=cfg.model.heads,
    dim_head=cfg.model.dim_head,
    dropout_rate=cfg.model.dropout_rate,
    num_res_blocks=cfg.model.num_res_blocks,
    attn_resolutions=cfg.model.attn_resolutions,
    final_activation=cfg.model.final_activation,
    q_dim=None,
    a_dim=None,
    key=model_key,
)

optimizer = hydra.utils.instantiate(cfg.optimizer)
opt_state = optimizer.init(eqx.filter(score_model, eqx.is_array))

Below, we construct the forward SDE module and begin the training process. Training stops after the number of steps specified in the config file (`cfg.training.n_steps`). Every 10 epochs, we show some simulated digits and save the current model and optimizer state.

In [None]:
beta_int_fcn = lambda t: t
forward_sde = ForwardSDE(beta_int_fcn)

counter = 0
n_per_epoch = train_loader.n_batch
epoch_losses = []
epoch_idx = 0
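for batch_idx, (X_batch, y_batch) in enumerate(
    tqdm.tqdm(train_loader.as_generator(), total=cfg.training.n_steps)
):
    # `total` only sizes the progress bar, so stop explicitly after n_steps.
    if batch_idx >= cfg.training.n_steps:
        break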
    score_model, opt_state, loss, train_key = make_step(
        score_model, X_batch, forward_sde, optimizer, opt_state, train_key
    )
    epoch_losses.append(loss)
    counter += 1
    
    if counter % n_per_epoch == 0:
        
        # Logging
        avg_epoch_loss = jnp.array(epoch_losses).mean()
        print(f"Epoch {epoch_idx}: Avg. Loss {avg_epoch_loss}")
        epoch_losses = []
        
        # Log example image
        if epoch_idx % 10 == 0:
            score_model = eqx.nn.inference_mode(score_model, True)
            reverse_sde = ReverseSDE(1e-1, score_model, forward_sde)
            nrow = 5
            ncol = 5
            K = nrow * ncol
            out = reverse_sde.sample_K(K=K, key=sample_key)

            fig, ax = plt.subplots(nrow, ncol)
            for i in range(nrow):
                for j in range(ncol):
                    entry = ncol * i + j
                    this_image = out[entry][0]
                    ax[i, j].imshow(this_image)
            plt.show()
            plt.clf()

            save_model(score_model, "current_model.eqx")

            # Save optimiser state
            save_opt_state(
                optimizer,
                opt_state,
                i=epoch_idx * n_per_epoch,
                filename="current_opt_state",
            )
        epoch_idx += 1
            