# Introduction to score-based generative modelling

In this notebook, we provide an introduction to generative modelling via the **score** of a distribution. If $q_d$ is the density of the distribution of interest, the score of $q_d$ is $ \nabla \log q_d(x)\,. $

**Score-based generative modelling** consists of estimating the score of a distribution of interest from a set of samples $\mathcal{D} = \{X_1, \cdots, X_n\}$. To this end, given a parametric family $s_{\theta}(x)$ (i.e. a given network structure), we search for the parameter

$$\theta \in \operatorname{argmin}_{\theta} \mathbb{E}_{X \sim q_d}[\|s_{\theta}(X) - \nabla \log q_d(X)\|^2] \,. $$

Of course, this optimization problem cannot be solved directly, since the score of the distribution is unknown.
In this notebook, we consider three approaches to learning the score of a distribution $q_d$ without knowing its density (or its score), introduced in the following references:

* [1] Hyvärinen, A. (2005). Estimation of Non-Normalized Statistical Models by Score Matching. Journal of Machine Learning Research, 6(24), 695–709. Retrieved from http://jmlr.org/papers/v6/hyvarinen05a.html

* [2] P. Vincent, "A Connection Between Score Matching and Denoising Autoencoders," in Neural Computation, vol. 23, no. 7, pp. 1661-1674, July 2011, doi: 10.1162/NECO_a_00142.

* [3] Song, Y., & Ermon, S. (2019). Generative modeling by estimating gradients of the data distribution. Advances in neural information processing systems, 32.
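
To make the definition of the score concrete, here is a small sketch in plain Python (one dimension, no JAX) that is not part of the original notebook. It compares the closed-form score of a Gaussian, $-(x - \mu)/\sigma^2$, with a finite-difference derivative of an *unnormalized* log-density. The function names and the values of `mu` and `sigma` are illustrative assumptions.

```python
def log_density_unnormalized(x, mu=1.0, sigma=2.0):
    # log q(x) up to an additive constant (the normalizer is dropped on purpose)
    return -0.5 * ((x - mu) / sigma) ** 2

def gaussian_score(x, mu=1.0, sigma=2.0):
    # Closed-form score of N(mu, sigma^2): d/dx log q(x) = -(x - mu) / sigma^2
    return -(x - mu) / sigma**2

def finite_difference_score(log_q, x, h=1e-5):
    # Central finite-difference approximation of d/dx log q(x)
    return (log_q(x + h) - log_q(x - h)) / (2 * h)

for x in (-3.0, 0.0, 1.0, 4.5):
    exact = gaussian_score(x)
    approx = finite_difference_score(log_density_unnormalized, x)
    print(f"x={x:+.1f}  exact={exact:+.4f}  finite-diff={approx:+.4f}")
```

The two columns agree even though the normalizing constant was never computed: it disappears under the gradient, which is what makes the score a convenient target for unnormalized models.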


In [1]:
# Imports
from jax import numpy as jnp, grad, random, vmap, value_and_grad, jit, jvp, jacfwd, devices
from jax.tree_util import Partial
from jax.lax import scan
import numpyro.distributions as dist
import flax.linen as nn
import matplotlib.pyplot as plt
import matplotlib.animation
import math
import optax
from tqdm.notebook import tqdm
import orbax.checkpoint

#Defining master key for jax
KEY = random.key(0)

%matplotlib inline
plt.rcParams["animation.html"] = "jshtml"
plt.rcParams['figure.dpi'] = 150
plt.rcParams['animation.embed_limit'] = 2**128
plt.ioff()

In [2]:
# Defining Noised version of original gaussian mixture. q_d corresponds to p_t(0).
def p_t(std_t):
    means = jnp.meshgrid(jnp.arange(-10., 15., 10), jnp.arange(-10., 15., 10))
    means = jnp.stack([m.flatten() for m in means], axis=1)
    covs = jnp.repeat(jnp.eye(2)[None]*(.1 + std_t**2), axis=0, repeats=means.shape[0])
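    # Normalize the random mixture weights so that they form a valid probability vector
    weights = random.uniform(random.key(42), (len(means),))
    weights = weights / weights.sum()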
    return dist.MixtureSameFamily(component_distribution=dist.MultivariateNormal(means,
                                                                                 covariance_matrix=covs),
                                  mixing_distribution=dist.Categorical(weights))

# Helper function to get the score values on a mesh
def get_mesh_score(score_fn):
    X, Y = jnp.meshgrid(jnp.linspace(-15, 15), jnp.linspace(-15, 15))
    xs = jnp.stack([X.flatten(), Y.flatten()], axis=1)
    scores = score_fn(xs)
    return X, Y, scores


## Unadjusted Langevin Algorithm

In this notebook, we will sample from a learned score model using the Unadjusted Langevin Algorithm (ULA). We refer to the following papers for a presentation of the algorithm and its performance.

* [4] Roberts, G. O., & Tweedie, R. L. (1996). Exponential convergence of Langevin distributions and their discrete approximations. Bernoulli, 341-363.
* [5] Durmus, A., Majewski, S., & Miasojedow, B. (2019). Analysis of Langevin Monte Carlo via convex optimization. The Journal of Machine Learning Research, 20(1), 2666-2711.

In [3]:
# ULA kernel
def ula(x, key, score_fun, learning_rate):
    noise = random.normal(key=key, shape=x.shape)
    return x + learning_rate * score_fun(x) + ((2*learning_rate)**.5)*noise

# Runs n_steps of ULA (a for loop written with scan) and returns every intermediate sample.
# To better understand the syntax of Jax's scan function go to: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html
def several_steps_ula(init, key, score_fun, learning_rate, n_steps):
    
    def my_ula(x, key):
        pred = ula(x, key, score_fun, learning_rate)
        return pred, pred
    return scan(f=my_ula,
                init=init,
                xs=random.split(key, n_steps))[-1]

# Helper that runs several independent ULA chains in parallel (one per initial point and key).
def multiple_chain_ula(init, keys, n_steps, score_fun, learning_rate):
    return vmap(Partial(several_steps_ula, n_steps=n_steps, learning_rate=learning_rate, score_fun=score_fun))(init, keys)
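

The ULA kernel above can be checked on a target whose behaviour is known in closed form. The following sketch (plain Python, one dimension, not part of the original notebook) runs ULA on $\mathcal{N}(0, 1)$, whose score is $-x$. In that case the chain is an AR(1) process and its stationary variance is $1/(1 - h/2)$ for step size $h$ (called `learning_rate` in the notebook), which illustrates the small bias of the *unadjusted* algorithm. All names and constants below are illustrative assumptions.

```python
import random

def ula_step(x, step_size, score, rng):
    # One ULA update: gradient step on the log-density plus Gaussian noise
    return x + step_size * score(x) + (2 * step_size) ** 0.5 * rng.gauss(0.0, 1.0)

def run_ula(x0, step_size, n_steps, score, seed=0):
    # Run a single chain and keep every intermediate state
    rng = random.Random(seed)
    x, chain = x0, []
    for _ in range(n_steps):
        x = ula_step(x, step_size, score, rng)
        chain.append(x)
    return chain

def ula_stationary_variance(step_size):
    # For the target N(0, 1) the update is x' = (1 - h) x + sqrt(2 h) * noise,
    # an AR(1) process with stationary variance 2h / (1 - (1 - h)^2) = 1 / (1 - h / 2).
    return 1.0 / (1.0 - step_size / 2.0)

step_size = 0.1
chain = run_ula(x0=5.0, step_size=step_size, n_steps=50_000, score=lambda x: -x)
kept = chain[1_000:]  # discard the burn-in phase
mean = sum(kept) / len(kept)
var = sum((x - mean) ** 2 for x in kept) / len(kept)
print(f"empirical mean={mean:.3f}, variance={var:.3f}, "
      f"predicted variance={ula_stationary_variance(step_size):.3f}")
```

The empirical variance should sit slightly above 1, close to the predicted value; shrinking the step size reduces this bias at the cost of slower mixing.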

In [4]:
# This cell initializes some generic global variables used throughout the notebook 
LANGEVIN_LR=1e-2
N_EPOCHS=3_000
BATCH_SIZE=2048
KEY, subkey = random.split(KEY, 2)
dataset = p_t(0).sample(subkey, (10_000,))

In [5]:
# Performs ULA using the real score
KEY, subkey_init, subkey_keys = random.split(KEY, 3)
samples_ula_0 = multiple_chain_ula(random.normal(subkey_init, shape=(20, 2))*10,
                                   random.split(subkey_keys, 20),
                                   learning_rate=LANGEVIN_LR,
                                   score_fun=jit(grad(lambda x: p_t(0).log_prob(x))),
                                   n_steps=1_000)

X, Y, scores = get_mesh_score(vmap(grad(p_t(0).log_prob)))

We consider $q_d$ to be a $2$-dimensional Gaussian mixture distribution.
Below we visualize the result of running ULA with the real score of $q_d$. We initialize the ULA chain with samples from $\mathcal{N}(0, 100 \operatorname{I})$.

In [6]:
# Animate ULA with the real score
KEY, subkey = random.split(KEY, 2)
samples = p_t(0).sample(subkey, (1000,))

fig, axes = plt.subplots(1, 2)

for ax in axes:
  ax.scatter(*samples[:1000].T, color='blue')
  ax.quiver(X, Y, scores[:, 0], scores[:, 1], color='black')
  ax.set_xlim(-15, 15)
  ax.set_ylim(-15, 15)
traj_artists, points_artists = [], []
for i, traj in enumerate(samples_ula_0[:, :1]):
    point_artist = axes[1].scatter(*traj[-1], color=plt.get_cmap("rainbow")(i / (samples_ula_0.shape[0]-1)))
    traj_artist, = axes[0].plot(*traj.T, color=plt.get_cmap("rainbow")(i / (samples_ula_0.shape[0]-1)))
    traj_artists.append(traj_artist)
    points_artists.append(point_artist)
artists = traj_artists + points_artists

def init():

  return artists


def update(frame):
    n_artists = len(artists) //2
    for i, traj in enumerate(samples_ula_0[:, :frame + 1]):
        artists[i].set_data(traj[:, 0], traj[:, 1])
        artists[i + n_artists].set_offsets(traj[-1])
    return artists

matplotlib.animation.FuncAnimation(fig,  update, frames=jnp.arange(0, 25, 1).tolist(),
                    init_func=init, blit=True)

## Hyvärinen Approach [1]

We start with the loss presented in [1]. Using integration by parts, one can show that the score matching objective is equivalent to:

$$ \theta \in \operatorname{argmin}_{\theta} \mathbb{E}_{X \sim q_d}[\nabla \cdot s_{\theta}(X) + \frac{1}{2} \|s_{\theta}(X)\|^2] \,. $$
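
The equivalence can be checked numerically on a toy problem. The sketch below (plain Python, not part of the original notebook) takes data from $\mathcal{N}(0, 1)$ and the hypothetical one-parameter model $s_\theta(x) = -\theta x$, whose divergence is simply $-\theta$. It compares the implicit objective above with the explicit one, which uses the true score $-x$ and a factor $\frac{1}{2}$ that does not change the minimizer. In expectation the two differ by a constant (here $\frac{1}{2}$); on a finite sample this holds only approximately.

```python
import random

def implicit_score_matching_loss(theta, xs):
    # Model s_theta(x) = -theta * x, so ds/dx = -theta and the loss is
    # the sample mean of  ds/dx + 0.5 * s_theta(x)^2  (no true score needed).
    return sum(-theta + 0.5 * (theta * x) ** 2 for x in xs) / len(xs)

def explicit_score_matching_loss(theta, xs):
    # Uses the true score of N(0, 1), which is -x (unknown in practice).
    return sum(0.5 * (-theta * x + x) ** 2 for x in xs) / len(xs)

rng = random.Random(0)
xs = [rng.gauss(0.0, 1.0) for _ in range(5_000)]

# Grid search over theta in [0, 2]
thetas = [i / 100 for i in range(0, 201)]
best_implicit = min(thetas, key=lambda t: implicit_score_matching_loss(t, xs))
best_explicit = min(thetas, key=lambda t: explicit_score_matching_loss(t, xs))
print("argmin implicit:", best_implicit, " argmin explicit:", best_explicit)

# Gap between the two losses at a few parameter values (about 0.5 each)
gaps = [explicit_score_matching_loss(t, xs) - implicit_score_matching_loss(t, xs)
        for t in (0.0, 0.5, 1.5)]
print("gaps:", [round(g, 3) for g in gaps])
```

Both minimizers land near $\theta = 1$, the parameter of the true score, although the implicit loss never evaluates $\nabla \log q_d$.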

We use a simple MLP (multi-layer perceptron) network to model the score $s_\theta$.

In [None]:
# This function defines a general training function used throughout the notebook
def train(dataset,
          learning_rate,
          net_params,
          loss_fn,
          key,
          n_epochs,
          batch_size = None,
         update_tqdm_period=100):
    tx = optax.adam(learning_rate=learning_rate)
    opt_state = tx.init(net_params)

    # define the heavy lifting function of the training loop
    def loop_fn_(sb_k, net_params, opt_state):
        sb_k_loss, sb_k_batch = random.split(sb_k, 2)
        if batch_size:
            batch_data = dataset[random.randint(key=sb_k_batch, maxval=dataset.shape[0], minval=0, shape=(batch_size,))]
        else:
            batch_data = dataset
        loss_val, grads = loss_fn(net_params, batch_data, sb_k_loss)
        updates, opt_state = tx.update(grads, opt_state)

        net_params = optax.apply_updates(net_params, updates)
        return loss_val, net_params, opt_state
    # Mark it to be just-in-time compiled
    loop_fn = jit(loop_fn_)


    pbar = tqdm(enumerate(random.split(key, n_epochs)))
    losses = []
    for i, sb_k in pbar:
        # One optimization step on a freshly sampled mini-batch
        loss_val, net_params, opt_state = loop_fn(sb_k, net_params, opt_state)

        # Periodically display the current loss in the progress bar
        if i % update_tqdm_period == 0:
            pbar.set_description(f"Loss: {loss_val:.2E}")
        losses.append(loss_val)
    return net_params, losses


# Loss from Hyvarinen paper.
def make_hyvarinen_loss(net, n_eps_hutch=100, data_dim=2, use_hutching_trick=False):
    if use_hutching_trick:
        # To be completed
        raise NotImplementedError
    else:
        def loss_fn(params, data, key):
            def score_twice(data):
                score = net.apply(params, data)
                return score, score
            jac, score = vmap(jacfwd(score_twice, has_aux=True))(data)

            div = jnp.diagonal(jac, axis1=-2, axis2=-1).sum(axis=-1)
            score_norm = (score**2).sum(axis=-1)

            return (div + .5 * score_norm).mean()
    return loss_fn
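

The `use_hutching_trick` branch above is left unimplemented in the notebook. Assuming it refers to Hutchinson's trace estimator, the idea is to replace the exact divergence $\nabla \cdot s_\theta = \operatorname{tr}(J)$ by the average of $v^\top J v$ over random probe vectors $v$ with $\mathbb{E}[v v^\top] = \operatorname{I}$. The sketch below (plain Python, with a fixed $2 \times 2$ matrix standing in for the Jacobian of the score network) only illustrates the estimator itself; it is not the author's implementation.

```python
import itertools
import random

def quadratic_form(v, matrix):
    # v^T A v for a square matrix given as a list of rows
    n = len(v)
    return sum(v[i] * matrix[i][j] * v[j] for i in range(n) for j in range(n))

def hutchinson_trace(matrix, n_probes, seed=0):
    # Monte Carlo estimate of tr(A) = E[v^T A v] with Rademacher probes v in {-1, +1}^d
    rng = random.Random(seed)
    d = len(matrix)
    total = 0.0
    for _ in range(n_probes):
        v = [rng.choice((-1.0, 1.0)) for _ in range(d)]
        total += quadratic_form(v, matrix)
    return total / n_probes

# Stand-in for the Jacobian of the score network at one data point (made-up values)
jacobian = [[-1.0, 0.3],
            [0.5, -2.0]]
exact_trace = sum(jacobian[i][i] for i in range(2))

# Averaging over all sign patterns recovers the trace exactly: off-diagonal terms cancel
all_probes = list(itertools.product((-1.0, 1.0), repeat=2))
exhaustive = sum(quadratic_form(v, jacobian) for v in all_probes) / len(all_probes)
estimate = hutchinson_trace(jacobian, n_probes=100)
print(exact_trace, exhaustive, estimate)
```

In two dimensions the exact Jacobian is cheap, so the estimator brings little here. In high dimension, each product $J v$ can be obtained with one Jacobian-vector product (for instance with `jax.jvp`, which the notebook imports) without ever forming $J$.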

In [None]:
# Simple Neural Net from Flax
class MLP(nn.Module):              
  out_dims: int

  @nn.compact
  def __call__(self, x):
      x = nn.Dense(128)(x)
      x = nn.LayerNorm()(x)
      x = nn.leaky_relu(x)
      x = nn.Dense(512)(x)
      x = nn.LayerNorm()(x)
      x = nn.leaky_relu(x)
      x = nn.Dense(256)(x)
      x = nn.LayerNorm()(x)  
      x = nn.leaky_relu(x)
      x = nn.Dense(self.out_dims)(x)
      return nn.leaky_relu(x)

In [None]:
# Initialization
model = MLP(out_dims=2)
x = jnp.empty((1, 2))
net_params = model.init(random.key(42), x)

In [None]:
#Training
KEY, subkey = random.split(KEY, 2)
net_params, losses = train(
    dataset=dataset,
    learning_rate=1e-4,
    net_params=net_params,
    loss_fn=value_and_grad(make_hyvarinen_loss(model)),
    key=subkey,
    n_epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
)

In [None]:
# Plotting the loss
fig, ax = plt.subplots(1, 1)
ax.plot(jnp.stack(losses))
fig

### Sampling with the Hyvärinen model

We now focus on sampling with the Hyvärinen model using ULA, with the same parameterization as in the case of the exact score and the same starting distribution.

In [None]:
# Sampling with ULA
KEY, subkey_init, subkey_keys = random.split(KEY, 3)
samples_ula_0 = multiple_chain_ula(random.normal(subkey_init, shape=(20, 2))*10,
                                   random.split(subkey_keys, 20),
                                   learning_rate=LANGEVIN_LR,
                                   score_fun=jit(lambda x: model.apply(net_params, x)),
                                   n_steps=1_000)
X, Y, scores = get_mesh_score(lambda xs: vmap(model.apply, in_axes=(None, 0))(net_params, xs))

In [None]:
# Animating with ULA
fig, axes = plt.subplots(1, 2)

traj_artists, points_artists = [], []
for i, traj in enumerate(samples_ula_0[:, :1]):
    point_artist = axes[1].scatter(*traj[-1], color=plt.get_cmap("rainbow")(i / (samples_ula_0.shape[0]-1)))
    traj_artist, = axes[0].plot(*traj.T, color=plt.get_cmap("rainbow")(i / (samples_ula_0.shape[0]-1)))
    traj_artists.append(traj_artist)
    points_artists.append(point_artist)
artists = traj_artists + points_artists

def init():
  for ax in axes:
      ax.scatter(*dataset[:1000].T, color='blue')
      ax.quiver(X, Y, scores[:, 0], scores[:, 1], color='black')
      ax.set_xlim(-15, 15)
      ax.set_ylim(-15, 15)
  return artists


def update(frame):
    n_artists = len(artists) //2
    for i, traj in enumerate(samples_ula_0[:, :frame + 1]):
        artists[i].set_data(traj[:, 0], traj[:, 1])
        artists[i + n_artists].set_offsets(traj[-1])
    return artists

matplotlib.animation.FuncAnimation(fig,  update, frames=jnp.arange(0, 1_000, 50).tolist(),
                    init_func=init, blit=True)

## Denoising Loss [2]:

In [2], the author proposes learning the score of a noised version of $q_d$, which we denote by $q_t$. Formally,
$$q_t(x) = \mathbb{E}_{X \sim q_d}[q_{t|0}(x | X)] \, ,$$
where $q_{t|0}(x | x_0) = \mathcal{N}(x; x_{0}, \upsilon_t^2 \operatorname{I})$ with $\upsilon_t > 0$.

In this case, estimating the score of $q_t$ via the score matching objective is equivalent to
$$ \operatorname{argmin}_{\theta}\mathbb{E}_{X \sim q_d, \epsilon \sim \mathcal{N}(0, \operatorname{I})} [\|s_{\theta}(X + \upsilon_t \epsilon) - \nabla \log q_{t|0}(X + \upsilon_t \epsilon|X)\|^2]  = \operatorname{argmin}_{\theta}\mathbb{E}_{X \sim q_d, \epsilon \sim \mathcal{N}(0, \operatorname{I})}[\|s_{\theta}(X + \upsilon_t \epsilon) + \upsilon_t^{-1}\epsilon \|^2]\, . $$
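
A point worth checking is that this objective recovers the score of the *noised* distribution $q_t$, not that of $q_d$. The sketch below (plain Python, not part of the original notebook) uses the toy setting $q_d = \mathcal{N}(0, \sigma_d^2)$ and the hypothetical one-parameter model $s_\theta(y) = -\theta y$. The denoising objective is then a least-squares problem with a closed-form minimizer, which can be compared with the score $-y / (\sigma_d^2 + \upsilon_t^2)$ of $q_t = \mathcal{N}(0, \sigma_d^2 + \upsilon_t^2)$.

```python
import random

def fit_linear_denoising_score(data_std, noise_std, n_samples=50_000, seed=0):
    # Model s_theta(y) = -theta * y. Minimizing the sample mean of
    # (s_theta(y) + eps / noise_std)^2 over theta is a least-squares problem
    # with solution theta = sum(y * eps / noise_std) / sum(y^2).
    rng = random.Random(seed)
    num, den = 0.0, 0.0
    for _ in range(n_samples):
        x = rng.gauss(0.0, data_std)   # clean sample X ~ q_d
        eps = rng.gauss(0.0, 1.0)      # standard Gaussian noise
        y = x + noise_std * eps        # noised sample, distributed as q_t
        num += y * eps / noise_std
        den += y * y
    return num / den

data_std, noise_std = 2.0, 1.0
theta_hat = fit_linear_denoising_score(data_std, noise_std)
theta_qt = 1.0 / (data_std**2 + noise_std**2)  # slope of the score of q_t
theta_qd = 1.0 / data_std**2                   # slope of the score of q_d
print(f"fitted={theta_hat:.4f}  q_t slope={theta_qt:.4f}  q_d slope={theta_qd:.4f}")
```

The fitted slope matches that of $q_t$, so the noise level $\upsilon_t$ controls how far the learned score is from the score of the data distribution.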

In [7]:
# Function that makes the denoising loss function
def make_denoising_loss(net, std=1e-2, data_dim=2):
    def loss_fn(params, data, key):
        eps = random.normal(key, (data.shape[0], data_dim))
        score = net.apply(params, data + std*eps)

        loss = ((score + eps/std)**2).sum(axis=-1)

        return loss.mean()
    return loss_fn

In [None]:
model = MLP(out_dims=2)
x = jnp.empty((1, 2))
net_params = model.init(random.key(42), x)

In [None]:
KEY, subkey = random.split(KEY, 2)
net_params, losses = train(
    dataset=dataset,
    learning_rate=1e-2,
    net_params=net_params,
    loss_fn=value_and_grad(make_denoising_loss(model, std=1)),
    key=subkey,
    n_epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    update_tqdm_period=1
)

In [None]:
fig, ax = plt.subplots(1, 1)
ax.plot(jnp.stack(losses))
fig

### Sampling from the learned model

In [None]:
KEY, subkey_init, subkey_keys = random.split(KEY, 3)
samples_ula_0 = multiple_chain_ula(random.normal(subkey_init, shape=(20, 2))*10,
                                   random.split(subkey_keys, 20),
                                   learning_rate=LANGEVIN_LR,
                                   score_fun=jit(lambda x: model.apply(net_params, x)),
                                   n_steps=1_000)

X, Y, scores = get_mesh_score(lambda xs: vmap(model.apply, in_axes=(None, 0))(net_params, xs))

In [None]:
fig, axes = plt.subplots(1, 2)

for ax in axes:
    ax.scatter(*dataset[:1000].T, color='blue', alpha=.6)
    ax.quiver(X, Y, scores[:, 0], scores[:, 1], color='black')
    ax.set_xlim(-15, 15)
    ax.set_ylim(-15, 15)

traj_artists, points_artists = [], []
for i, traj in enumerate(samples_ula_0[:, :1]):
    point_artist = axes[1].scatter(*traj[-1], color=plt.get_cmap("rainbow")(i / (samples_ula_0.shape[0]-1)), alpha=.8)
    traj_artist, = axes[0].plot(*traj.T, color=plt.get_cmap("rainbow")(i / (samples_ula_0.shape[0]-1)), alpha=.8)
    traj_artists.append(traj_artist)
    points_artists.append(point_artist)
artists = traj_artists + points_artists

def init():

  return artists


def update(frame):
    n_artists = len(artists) //2
    for i, traj in enumerate(samples_ula_0[:, :frame + 1]):
        artists[i].set_data(traj[:, 0], traj[:, 1])
        artists[i + n_artists].set_offsets(traj[-1])
    return artists

matplotlib.animation.FuncAnimation(fig,  update, frames=jnp.arange(0, 1_000, 10).tolist(),
                    init_func=init, blit=True)

## Noise Conditional Score Networks (NCSN) [3]

We now consider the approach presented in [3], which consists of using a single network $s_\theta(x, \upsilon_t)$ that jointly approximates the scores of a sequence of distributions $\{q_t\}_{t \in [\varepsilon, 1]}$, defined as above for a sequence of noise levels $\{\upsilon_t\}_{t \in [\varepsilon, 1]}$. Here, we define $\upsilon_\varepsilon=0.02$, $\upsilon_1 = 10$ and $\upsilon_t = (\upsilon_\varepsilon^{1/\rho} + \frac{t - \varepsilon}{1-\varepsilon} (\upsilon_1^{1/\rho} - \upsilon_\varepsilon^{1/\rho}))^\rho$ with $\rho=5$.

To do so, we consider the following loss
$$ \operatorname{argmin}_{\theta}\mathbb{E}_{t \sim \mathcal{U}(\varepsilon, 1)}\left[\gamma_t^2\mathbb{E}_{X \sim q_d, \epsilon \sim \mathcal{N}(0, \operatorname{I})}[\|s_{\theta}(X + \upsilon_t \epsilon, \upsilon_t) + \upsilon_t^{-1}\epsilon \|^2]\right]\, ,$$
where $\gamma_t = \upsilon_t$ is a weighting coefficient.
The goal of this sequence of distributions is to allow for efficient sequential ULA steps: since $q_{t_1}$ and $q_{t_2}$ are close when $t_1 \approx t_2$, samples from $q_{t_2}$ provide a good initialization of ULA for $q_{t_1}$. Furthermore, choosing $\varepsilon$ small allows $q_\varepsilon$ to approximate $q_d$ closely.
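
The noise schedule $\upsilon_t$ defined above can be evaluated directly. The sketch below (plain Python, not part of the original notebook) implements the formula and prints a few values. The value of `eps` is a hypothetical choice, since the text does not fix $\varepsilon$.

```python
def noise_std(t, eps=1e-3, std_min=0.02, std_max=10.0, rho=5.0):
    # Interpolate linearly in std^(1/rho) between std_min (t = eps) and std_max (t = 1)
    frac = (t - eps) / (1.0 - eps)
    root = std_min ** (1 / rho) + frac * (std_max ** (1 / rho) - std_min ** (1 / rho))
    return root ** rho

eps = 1e-3  # hypothetical value of epsilon, chosen for illustration only
ts = [eps + i * (1.0 - eps) / 9 for i in range(10)]
stds = [noise_std(t, eps=eps) for t in ts]
for t, s in zip(ts, stds):
    print(f"t={t:.3f}  std={s:.4f}")
```

With $\rho = 5$ the schedule is strongly skewed towards small noise levels: the midpoint $t = 0.5$ maps to a standard deviation close to 1 rather than 5, so more of the sequence is spent near the data distribution.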

We start by visualizing the sequence of distributions $\{q_t\}_{t \in [\varepsilon, 1]}$ defined above.

In [None]:
KEY, subkey = random.split(KEY, 2)
std_min = 0.02
std_max = 10
p=5
stds = (std_max ** (1/p) + jnp.linspace(1, 0, 100)*(std_min**(1/p) - std_max**(1/p)))**p #taken from Karras2022 (https://arxiv.org/pdf/2206.00364.pdf)
samples = vmap(lambda std_t, key: p_t(std_t).sample(key, sample_shape=(1_000,)))(stds, random.split(subkey, 100))

In [None]:
fig, ax = plt.subplots()
def animate(t):
    plt.cla()
    plt.scatter(*samples[t].T)
    plt.xlim(-15, 15)
    plt.ylim(-15, 15)

matplotlib.animation.FuncAnimation(fig, animate, frames=100)

In [None]:
def make_ncns_loss(net, std_min=1e-2, std_max=10, data_dim=2, n_time_samples=25):
    def loss_fn(params, data, key):
        key_uniform, key_noise = random.split(key, 2)
        #stds = random.uniform(key_uniform, minval=std_min, maxval=std_max, shape=(data.shape[0], n_time_samples, 1))
        stds = jnp.exp(random.normal(key_uniform, shape=(data.shape[0], n_time_samples, 1))*1.2 - .5).clip(std_min, std_max)
        # Use the dedicated noise key (not the parent key, which was already split)
        eps = random.normal(key_noise, (data.shape[0], n_time_samples, data_dim))
        x_t = data[:, None] + stds * eps
        score = net.apply(params, x_t, stds)

        loss = ((score + eps/stds)**2).sum(axis=-1)
        loss_weights = stds[..., 0]**2
        return  (loss_weights*loss).mean()
    return loss_fn

In [None]:
class MLP(nn.Module):                    
  out_dims: int

  @nn.compact
  def __call__(self, x, std):
    std_emb = .25 * jnp.log(std)
    x = nn.Dense(256)(jnp.concatenate((x, std_emb), axis=-1))
    x = nn.LayerNorm()(x)              
    x = nn.leaky_relu(x)
    x = nn.Dense(512)(x)
    x = nn.leaky_relu(x)
    x = nn.Dense(256)(x)   
    x = nn.LayerNorm()(x)               
    x = nn.leaky_relu(x)
    x = nn.Dense(self.out_dims)(x)       
    return nn.leaky_relu(x)

model = MLP(out_dims=2)
x, std = jnp.empty((1, 2)), jnp.empty((1, 1))
net_params = model.init(random.key(42), x, std)

In [None]:
KEY, subkey = random.split(KEY, 2)
net_params, losses = train(
    dataset=dataset,
    learning_rate=1e-2,
    net_params=net_params,
    loss_fn=value_and_grad(make_ncns_loss(model, std_min=std_min, std_max=std_max, n_time_samples=2)),
    key=subkey,
    n_epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    update_tqdm_period=1
)

In [None]:
fig, ax = plt.subplots(1, 1)
ax.plot(jnp.stack(losses))
fig

## Sampling with NCSN

We now run ULA sequentially for each $q_t$, starting from the noisiest distribution $q_1$ (the last of the 100 noise levels in `stds`) and moving towards $q_\varepsilon$.

In [None]:
def sequential_ula(
        initial_samples,
        key,
        stds_sequence,
        n_iter_per_timestep,
        score_fun,
        learning_rate_ratio,
        n_chains,
):
    def scan_fn(x, meta):
        key, lr, std = meta
        samples = multiple_chain_ula(x,
                                     random.split(key, n_chains),
                                     learning_rate=lr,
                                     score_fun=Partial(score_fun, std=std[None]),
                                     n_steps=n_iter_per_timestep)
        return samples[:, -1], samples
    return scan(f=scan_fn,
                init=initial_samples,
                xs=(random.split(key, len(stds_sequence)),
                    learning_rate_ratio * (stds_sequence / stds_sequence[-1])**2,
                    stds_sequence))[-1]
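

The step size used by `sequential_ula` at each noise level is `learning_rate_ratio * (std / std_last)**2`, the scaling proposed in [3]. The sketch below (plain Python, with made-up noise levels) shows one way to read this choice: for a Gaussian of variance `std**2`, whose score is $-x/\mathrm{std}^2$, the drift of one ULA step multiplies $x$ by $1 - h/\mathrm{std}^2$, and this factor is then the same at every level.

```python
def annealed_step_sizes(stds_sequence, learning_rate_ratio):
    # Step size for each noise level, proportional to std^2 and equal to
    # learning_rate_ratio at the last (smallest) level of the sequence.
    last = stds_sequence[-1]
    return [learning_rate_ratio * (s / last) ** 2 for s in stds_sequence]

stds_sequence = [10.0, 1.0, 0.1, 0.02]  # decreasing noise levels (illustrative values)
step_sizes = annealed_step_sizes(stds_sequence, learning_rate_ratio=1e-4)
contractions = [1 - h / s**2 for s, h in zip(stds_sequence, step_sizes)]
for s, h, c in zip(stds_sequence, step_sizes, contractions):
    print(f"std={s:<5}  step={h:.3e}  contraction={c:.4f}")
```

Without this rescaling, a step size tuned for the smallest noise level would barely move the chain at the largest one.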

In [None]:
# Use distinct keys for the initial samples and for the Langevin noise
KEY, subkey_init, subkey_ula = random.split(KEY, 3)
samples_ncnn = sequential_ula(
    random.normal(subkey_init, shape=(20, 2))*std_max,
    subkey_ula,
    stds_sequence=stds[::-1],
    n_iter_per_timestep=5,
    score_fun=jit(lambda x, std: model.apply(net_params, x, std)),
    learning_rate_ratio=1e-4,
    n_chains=20
)
score_quivers = []
for std in stds[::-1]:
    X, Y, scores = get_mesh_score(lambda xs: vmap(model.apply, in_axes=(None, 0, None))(net_params, xs, std[None]))
    score_quivers.append(scores / jnp.abs(scores).max(axis=0)[None, :])

In [None]:
fig, axes = plt.subplots(1, 2)

for ax in axes:
    ax.scatter(*dataset[:1000].T, color='blue', alpha=.6)
    ax.quiver(X, Y, scores[:, 0], scores[:, 1], color='black')
    ax.set_xlim(-15, 15)
    ax.set_ylim(-15, 15)

traj_artists, points_artists, quiver_artists = [], [], []
for ax in axes:
    quiver_artists.append(ax.quiver(X, Y, score_quivers[0][:, 0], score_quivers[0][:, 1], color='black'))
for i, traj in enumerate(jnp.swapaxes(samples_ncnn[:1], 0, 1)):
    point_artist = axes[1].scatter(*traj[-1, -1], color=plt.get_cmap("rainbow")(i / (samples_ncnn.shape[1]-1)), alpha=.8)
    traj_artist, = axes[0].plot(*traj.reshape(-1, 2).T, color=plt.get_cmap("rainbow")(i / (samples_ncnn.shape[1]-1)), alpha=.5)
    traj_artists.append(traj_artist)
    points_artists.append(point_artist)
artists = traj_artists + points_artists + quiver_artists

def init():

  return artists


def update(frame):
    n_artists = (len(artists) - 2) //2
    for i, traj in enumerate(jnp.swapaxes(samples_ncnn[max(frame-40, 0):frame + 1], 0, 1)):
        artists[i].set_data(traj[..., 0].reshape(-1), traj[..., 1].reshape(-1))
        artists[i + n_artists].set_offsets(traj[-1, -1])
    artists[-2].set_UVC(score_quivers[frame][:, 0], score_quivers[frame][:, 1])
    artists[-1].set_UVC(score_quivers[frame][:, 0], score_quivers[frame][:, 1])
    return artists

matplotlib.animation.FuncAnimation(fig,  update, frames=jnp.arange(0, 100, 2).tolist(),
                    init_func=init, blit=True)

## DDIM Sampling

Note that the sequence of distributions defined above matches the marginals of the following Markov chain:

$$ X_t = X_{t-1} + (\upsilon_t^2 - \upsilon_{t-1}^2)^{1/2} \epsilon_t \,,$$

where $\epsilon_t \sim \mathcal{N}(0, \operatorname{I})$ and $X_0 \sim q_{d}$ (with the convention $\upsilon_0 = 0$). We denote the law of $X_t$ given $X_{t-1}$ by $q_{t|t-1}(x_t|x_{t-1}) = \mathcal{N}(x_t; x_{t-1}, (\upsilon_t^2 - \upsilon_{t-1}^2) \operatorname{I})$.

### Inference distribution

We now focus on the DDIM sampler [6]. The goal of DDIM is to construct a backward Markov chain that closely matches the Markov chain defined above. To do so, it first considers the law of $X_{1:T}$ given $X_{0}$.
It relies on the inference distribution

$$ q_{1:T}^{\eta}(x_{1:T} | x_0) = q_{T|0}(x_T | x_0) \prod_{t=2}^{T} q_{t-1 | t, 0}^{\eta_{t-1}}(x_{t-1}|x_t, x_0)\,,$$

where $\eta = (\eta_0, \ldots, \eta_{T-1}) \in [0, \infty) \times (0, \upsilon_1) \times \cdots \times (0, \upsilon_{T-1})$ and

$$  q_{t-1 | t, 0}^{\eta_{t-1}}(x_{t-1}|x_t, x_0) = \mathcal{N}(x_{t-1}; x_0 + [(\upsilon_{t-1}^2 - \eta_{t-1}^2)^{1/2} / \upsilon_{t}] (x_t - x_0), \eta_{t-1}^2 \operatorname{I}) \,. $$

The key property of the inference distribution is that, whatever the choice of $\eta$, it matches the law of $X_t$ given $X_0$:

$$ q_{t|0}(x_t | x_0) = \int q_{1:T}^{\eta}(x_{1:T} | x_0) dx_{1:t-1} dx_{t+1:T} = \mathcal{N}(x_t; x_0, \upsilon_{t}^2 \operatorname{I}) \,.$$

Moreover, for the *particular choice* $\eta_{t-1} = (\upsilon_t^2 - \upsilon_{t-1}^2)^{1/2} (\upsilon_{t-1} / \upsilon_t)$ we retrieve the Bayes decomposition of the forward chain:

$$ q_{t-1 | t, 0}^{\eta_{t-1}}(x_{t-1}|x_t, x_0) = \frac{q_{t|t-1}(x_t | x_{t-1}) q_{t-1|0}(x_{t-1}|x_{0})}{q_{t|0}(x_t|x_0)}\,.$$

[6] Song, J., Meng, C., & Ermon, S. (2020, October). Denoising Diffusion Implicit Models. In International Conference on Learning Representations.
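
The marginal-matching property of the inference distribution can be verified with a short variance computation. If $x_t - x_0 \sim \mathcal{N}(0, \upsilon_t^2)$, the kernel above gives $x_{t-1} - x_0$ a variance of $(\upsilon_{t-1}^2 - \eta^2) + \eta^2 = \upsilon_{t-1}^2$ for any admissible noise level $\eta$ of the kernel. The sketch below (plain Python, one dimension, illustrative values, not part of the original notebook) checks this and also checks that the Bayes choice of $\eta$ yields the posterior coefficient $\upsilon_{t-1}^2 / \upsilon_t^2$.

```python
def inference_coefficient(std_t, std_prev, eta_prev):
    # Mean of q(x_{t-1} | x_t, x_0) is x_0 + coeff * (x_t - x_0)
    return ((std_prev**2 - eta_prev**2) ** 0.5) / std_t

def marginal_variance_prev(std_t, std_prev, eta_prev):
    # If x_t - x_0 ~ N(0, std_t^2), then x_{t-1} - x_0 is Gaussian with
    # variance coeff^2 * std_t^2 + eta_prev^2.
    coeff = inference_coefficient(std_t, std_prev, eta_prev)
    return coeff**2 * std_t**2 + eta_prev**2

def bayes_eta(std_t, std_prev):
    # Posterior standard deviation of x_{t-1} given (x_t, x_0) for the forward chain
    return (std_t**2 - std_prev**2) ** 0.5 * std_prev / std_t

std_t, std_prev = 2.0, 1.5
etas_to_try = (0.0, 0.5, bayes_eta(std_t, std_prev), 1.5)
variances = [marginal_variance_prev(std_t, std_prev, e) for e in etas_to_try]
for e, v in zip(etas_to_try, variances):
    print(f"eta={e:.3f}  variance of x_(t-1) - x_0 = {v:.6f}")

coeff_bayes = inference_coefficient(std_t, std_prev, bayes_eta(std_t, std_prev))
print(coeff_bayes, std_prev**2 / std_t**2)
```

Every choice of the kernel noise gives the same marginal variance `std_prev**2`; only the amount of randomness injected at each backward step changes.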

In [None]:
def inference_process(x_0, x_t, sigma_t, sigma_t_1, eta_t_1):
    eps = x_t - x_0
    coeff_eps = ((sigma_t_1**2 - eta_t_1**2) / (sigma_t**2))**.5
    return x_0 + coeff_eps * eps


def inference_process_sampling(x_T, key, x_0, etas, sigmas):
    def _infproc_step(x_t, meta):
        key, eta_t_1, sig_t, sig_t_1 = meta
        x_t_1 = inference_process(x_0, x_t, sig_t, sig_t_1, eta_t_1)
        x_t_1 = x_t_1 + eta_t_1 * random.normal(key, shape=x_t_1.shape)
        return x_t_1, x_t_1
    n_steps = len(etas)
    return scan(f=_infproc_step,
                init=x_T,
                xs=(random.split(key, n_steps), etas[::-1], sigmas[1:][::-1], sigmas[:-1][::-1]))[-1]


def make_inference_process_sampler(N, eps, p=5, std_min=0.02, std_max=10.):
    stds = (std_max ** (1/p) + jnp.linspace(1, 0, N)*(std_min**(1/p) - std_max**(1/p)))**p
    etas = ((stds[1:] ** 2 - stds[:-1] **2)**.5) * (stds[:-1] / stds[1:]) * eps
    return jit(vmap(Partial(inference_process_sampling,
                            etas=etas,
                            sigmas=stds), in_axes=(0, 0, None)))

In [None]:
bayes_inf_proc_sampler = make_inference_process_sampler(100, 1)
other_inf_proc_sampler = make_inference_process_sampler(100, 0.1)

In [None]:
KEY, subkey_init, subkey_bayes, subkey_2 = random.split(KEY, 4)
initial_samples = random.normal(subkey_init, shape=(20, 1))*std_max

bayes_samples = bayes_inf_proc_sampler(initial_samples, random.split(subkey_bayes, 20), jnp.zeros((1,)))
other_samples = other_inf_proc_sampler(initial_samples, random.split(subkey_2, 20), jnp.zeros((1,)))

In [None]:
n_steps = other_samples.shape[1]
range_to_plot = jnp.linspace(1, 0, n_steps)
fig, ax = plt.subplots()

ax.set_xlim(0, 1)
ax.set_ylim(-10, 10)
ax.set_yscale('symlog')
ax.fill_between(range_to_plot[::-1], -stds[:-1]*3, stds[:-1]*3, color='red', alpha=.3)

artists = [
    ax.plot([], [],  linestyle='dashed', color='blue', marker='o', markersize=3)[0], 
    ax.plot([], [],  linestyle='dashed', color='green', marker='o', markersize=3)[0]
]

def init():

  return artists

def update(t):

    artists[0].set_data(range_to_plot[:t+1], bayes_samples[0, :t+1])
    artists[1].set_data(range_to_plot[:t+1], other_samples[0, :t+1])
    return artists

matplotlib.animation.FuncAnimation(
    fig,
    update,
    frames=jnp.arange(0, 100, 1).tolist(),
    init_func=init, blit=True
)

### DDIM sampling

Then, to generate the backward chain of DDIM, one simply replaces the $X_0$ term in every $q_{t-1|t, 0}$ by Tweedie's approximation of the posterior mean obtained with the score network:

$$ \mu_{t, \theta}(x_t) = x_t + \upsilon_t^2 s_{\theta}(x_t, \upsilon_t)\,. $$

By replacing $q_{T|0}$ by $\lambda = \mathcal{N}(0, \upsilon_T^2 \operatorname{I})$ we obtain

$$p_{0:T}(x_{0:T}) = \lambda(x_T) \prod_{t=1}^{T} p_{t-1|t}(x_{t-1}| x_t) \,,$$

where $p_{t-1|t} = q_{t-1|t, 0}(x_{t-1}|x_{t}, \mu_{t, \theta}(x_t))$ for $t > 1$ and $p_{0|1} = \mathcal{N}(\mu_{1, \theta}(x_1), \eta_0^2 \operatorname{I})$.
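
Tweedie's formula can be checked in a case where everything is available in closed form. The sketch below (plain Python, one dimension, not part of the original notebook) assumes Gaussian data $q_d = \mathcal{N}(m, \sigma_d^2)$, so that the exact score of $q_t$ is known, and compares $x_t + \upsilon_t^2 \nabla \log q_t(x_t)$ with the conjugate-Gaussian posterior mean. It then performs one deterministic DDIM step ($\eta = 0$). All numerical values are illustrative.

```python
def gaussian_marginal_score(x_t, std_t, data_mean, data_std):
    # If q_d = N(data_mean, data_std^2), then q_t = N(data_mean, data_std^2 + std_t^2)
    return -(x_t - data_mean) / (data_std**2 + std_t**2)

def tweedie_mean(x_t, std_t, score):
    # Tweedie's formula: E[X_0 | X_t = x_t] = x_t + std_t^2 * score(x_t)
    return x_t + std_t**2 * score

def ddim_step(x_t, std_t, std_prev, score):
    # Deterministic DDIM update (eta = 0): move x_t towards the predicted x_0
    pred_x0 = tweedie_mean(x_t, std_t, score)
    return pred_x0 + (std_prev / std_t) * (x_t - pred_x0)

data_mean, data_std = 3.0, 0.5
x_t, std_t, std_prev = 6.0, 2.0, 1.0
score = gaussian_marginal_score(x_t, std_t, data_mean, data_std)
pred_x0 = tweedie_mean(x_t, std_t, score)
# Posterior mean of X_0 given x_t, computed independently of the score
posterior_mean = (data_std**2 * x_t + std_t**2 * data_mean) / (data_std**2 + std_t**2)
x_prev = ddim_step(x_t, std_t, std_prev, score)
print(pred_x0, posterior_mean, x_prev)
```

The DDIM step lands between the predicted $x_0$ and the current $x_t$, with the residual $x_t - \hat{x}_0$ shrunk by the ratio of consecutive noise levels.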

In [None]:
def ddim_sampling(x_T, key, etas, sigmas, score_net):
    
    def _ddim_step(x_t, meta):
        key, eta_t_1, sig_t, sig_t_1 = meta
        pred_x_0 = x_t + (sig_t**2) * score_net(x_t, sig_t)
        x_t_1 = inference_process(pred_x_0, x_t, sig_t, sig_t_1, eta_t_1)
        x_t_1 = x_t_1 + eta_t_1 * random.normal(key, shape=x_t_1.shape)
        return x_t_1, x_t_1
        
    n_steps = len(etas) - 1
    # Use separate keys for the scan and for the final step to avoid reusing randomness
    key_scan, key_last = random.split(key, 2)
    x_1, x_traj = scan(f=_ddim_step,
                init=x_T,
                xs=(random.split(key_scan, n_steps), etas[1:][::-1], sigmas[1:][::-1], sigmas[:-1][::-1]))
    x_0 = x_1 + (sigmas[0]**2) * score_net(x_1, sigmas[0])
    x_0 = x_0 + etas[0] * random.normal(key_last, shape=x_0.shape)
    return jnp.concatenate((x_traj, x_0[None]), axis=0)


def make_ddim_sampler(N, eps, score_net, eta_0=0., p=5, std_min=0.02, std_max=10.):
    stds = (std_max ** (1/p) + jnp.linspace(1, 0, N)*(std_min**(1/p) - std_max**(1/p)))**p
    etas = jnp.clip(((stds[1:] ** 2 - stds[:-1] **2)**.5) * (stds[:-1] / stds[1:]) * eps, 0, stds[:-1] - 1e-8)
    etas = jnp.concatenate((jnp.array([eta_0]), etas), axis=0)
    return jit(vmap(Partial(ddim_sampling,
                            etas=etas,
                            sigmas=stds,
                           score_net=score_net), in_axes=(0, 0)))

In [None]:
bayes_ddim_sampler = make_ddim_sampler(50, 1, score_net=lambda x, std: model.apply(net_params, x, std[None]))
other_ddim_sampler = make_ddim_sampler(50, 0.2, score_net=lambda x, std: model.apply(net_params, x, std[None]))

In [None]:
KEY, subkey_init, subkey_bayes, subkey_other = random.split(KEY, 4)
initial_samples = random.normal(subkey_init, shape=(100, 2))*std_max

bayes_samples = bayes_ddim_sampler(initial_samples, random.split(subkey_bayes, 100))
other_samples = other_ddim_sampler(initial_samples, random.split(subkey_other, 100))


In [None]:
fig, axes = plt.subplots(1, 2)

for ax in axes:
    ax.scatter(*dataset[:1000].T, color='blue', alpha=.6)
    # ax.quiver(X, Y, scores[:, 0], scores[:, 1], color='black')
    ax.set_xlim(-15, 15)
    ax.set_ylim(-15, 15)
axes[0].set_title('DDIM other')
axes[1].set_title('DDIM Bayes')
other_artists, bayes_artists, quiver_artists = [], [], []
# for ax in axes:
#     quiver_artists.append(ax.quiver(X, Y, score_quivers[0][:, 0], score_quivers[0][:, 1], color='black'))

other_artists.append(axes[0].scatter(*initial_samples.T, color='orange'))
bayes_artists.append(axes[1].scatter(*initial_samples.T, color='red'))
artists = other_artists + bayes_artists + quiver_artists

def init():

  return artists


def update(t):
    f1, f2 = t
    artists[0].set_offsets(other_samples[:, f1])
    artists[1].set_offsets(bayes_samples[:, f2])
    # artists[-2].set_UVC(score_quivers[frame][:, 0], score_quivers[frame][:, 1])
    # artists[-1].set_UVC(score_quivers[frame][:, 0], score_quivers[frame][:, 1])
    return artists

min_len = min(bayes_samples.shape[1], other_samples.shape[1])
# Frame indices must stay within [0, n_frames - 1]
t2 = jnp.trunc(jnp.linspace(0, bayes_samples.shape[1] - 1, min_len)).astype(int)
t1 = jnp.trunc(jnp.linspace(0, other_samples.shape[1] - 1, min_len)).astype(int)
matplotlib.animation.FuncAnimation(fig,  update, frames=list(zip(t1.tolist(), t2.tolist())),
                    init_func=init, blit=True)

In [None]:
from ot import max_sliced_wasserstein_distance
KEY, subkey_init = random.split(KEY, 2)
initial_samples = random.normal(subkey_init, shape=(2000, 2))*std_max
ddim_sws = {}
for eps in [0., 0.1, 0.5, 1.]:
    ddim_sws[eps] = {'N': [], 'sw': []}
    for N in jnp.arange(2, 11, 1)**2:
        KEY, subkey_sampler = random.split(KEY, 2)
        samples = make_ddim_sampler(N, eps, score_net=lambda x, std: model.apply(net_params, x, std[None]))(initial_samples, random.split(subkey_sampler, initial_samples.shape[0]))
        sw = max_sliced_wasserstein_distance(samples[:,-1], dataset, n_projections=1000)
        ddim_sws[eps]['N'].append(N)
        ddim_sws[eps]['sw'].append(sw)

In [None]:
fig, ax = plt.subplots(1, 1)
for c, (eps, eps_data) in zip(plt.get_cmap("rainbow")(jnp.arange(len(ddim_sws)) / (len(ddim_sws) - 1)), ddim_sws.items()):
    ax.plot(eps_data['N'], eps_data['sw'], color=c, marker='*', linestyle='dashed', label=eps)
ax.set_xscale('log')
ax.set_ylabel('Max-sliced Wasserstein')
ax.set_xlabel('N')
ax.legend()
fig

# Extra: CelebA-HQ dataset and Hugging Face

We now consider the CelebA-HQ 256 dataset and use the pretrained model from Hugging Face's diffusers library (https://github.com/huggingface/diffusers).

The main difference from the previous sections is that, as in the DDIM paper (Song, Meng & Ermon), the pretrained model uses the *Variance Preserving* framework, which corresponds to simply changing the perturbation kernel to:

$$ q_{t|0}(x_t|x_0) = \mathcal{N}(x_t; \sqrt{\alpha_t} x_0, (1 - \alpha_t) \operatorname{I})\,,$$

where $\{\alpha_t\}_{t \in [\varepsilon, 1]}$ is such that $\alpha_t \in (0, 1)$ and $\lim_{t\rightarrow 1} \alpha_t = 0$. All the calculations above carry over to this perturbation kernel, except that the initial distribution becomes $\lambda = \mathcal{N}(0, \operatorname{I})$.
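
One way to see why the calculations carry over is that a Variance Preserving sample, divided by $\sqrt{\alpha_t}$, is a sample from the additive-noise kernel used in the previous sections with noise level $\sqrt{(1 - \alpha_t)/\alpha_t}$. The sketch below (plain Python, one dimension, illustrative values, not part of the original notebook) checks this identity on a single draw.

```python
import random

def vp_sample(x0, alpha_t, eps):
    # Variance Preserving kernel: x_t = sqrt(alpha_t) * x0 + sqrt(1 - alpha_t) * eps
    return alpha_t**0.5 * x0 + (1.0 - alpha_t) ** 0.5 * eps

def additive_noise_sample(x0, std_t, eps):
    # Kernel used earlier in the notebook: x_t = x0 + std_t * eps
    return x0 + std_t * eps

def equivalent_std(alpha_t):
    # Noise level of the additive kernel that matches the rescaled VP sample
    return ((1.0 - alpha_t) / alpha_t) ** 0.5

rng = random.Random(0)
x0, alpha_t = 1.7, 0.36
eps = rng.gauss(0.0, 1.0)
rescaled_vp = vp_sample(x0, alpha_t, eps) / alpha_t**0.5
matching = additive_noise_sample(x0, equivalent_std(alpha_t), eps)
print(rescaled_vp, matching)
```

As $\alpha_t \to 0$ the equivalent noise level diverges, while the unscaled VP sample itself approaches a standard Gaussian, which is why the initial distribution is $\mathcal{N}(0, \operatorname{I})$ in this framework.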

In [None]:
from diffusers import DDIMPipeline

model_id = "google/ddpm-ema-celebahq-256"

pipe = DDIMPipeline.from_pretrained(model_id)

In [None]:
image = pipe(eta=1, num_inference_steps=10).images[0]
image

In [None]:
image = pipe(eta=0, num_inference_steps=10).images[0]
image