## Memorization Tutorial

In this colab we showcase an example of memorization, in which a diffusion model only memorizes the training dataset while failing to learn the underlying distribution.

## Recap: How do score-based generative models work?
Given access to samples $\{x_i\}$ from a target distribution with *unknown* density $p_{\text{data}}$, generative models seek to learn the underlying distribution. The distribution is learned implicitly by learning a sampler, which draws samples from $p_{\text{data}}$.

Here we will use a simple low-dimensional distribution, and we will try to learn it from some of its samples, using a *naive*, albeit intuitive, approach to learn the score function.



### Forward SDE
Score-based generative models (SGMs) draw samples from the underlying distribution by leveraging the properties of a diffusive stochastic differential equation, which transforms any distribution into a Gaussian distribution. This equation is usually called the **forward SDE**; we then seek to **solve it backwards**, thus effectively transforming samples from a Gaussian distribution into samples from our target one.

Under a few regularity assumptions, any SDE of the form
$$\begin{array}{rcl}
        \mathrm{d}X_t & = & f(X_t,t)\mathrm{d}t + g(t)\mathrm{d}W_t \\
        X_0 & \sim & p_{\text{data}}.
\end{array}
$$
with *drift coefficient* $f$ and *diffusion coefficient* $g$, has a corresponding Fokker-Planck equation given by
$$\partial_{t}p(x,t) = -\nabla_{x}\cdot(f(x,t)p(x,t)) + \frac{1}{2}g(t)^{2}\Delta_{x} p(x,t).$$

From now on we denote the marginal distribution by $p_{t}(x) := p(x,t)$, and we note that $p_{0} = p_{\text{data}}$. We suppose that we solve this SDE until a terminal time $T=1$.

### Reverse SDE
Fortunately, the forward SDE admits the following **reverse-time SDE**, which is solved backwards in time from $t = 1$ to $t = 0$:
$$\begin{array}{rcl}
        \mathrm{d}X_t & = & [f(X_t,t) - g(t)^2\nabla_x \log{p(X_t, t)}]\mathrm{d}t + g(t)\mathrm{d} \bar{W}_t,\\
        X_1 & \sim & N(0, \sigma^2_{\text{max}} I),
\end{array}$$
where $\nabla_x \log p(x, t)$ is called the *score function*, $\bar{W}_t$ is a backward Wiener process, and $\sigma^2_{\text{max}}$ is the variance of the noise at the terminal time.

Thus, if we knew the score $\nabla_x \log p(x, t)$, we could run the reverse SDE and generate samples $X_0 \sim p_{\text{data}}$.
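
To connect the reverse-time SDE with a numerical scheme that steps forward, one can change variables. The following is a sketch, where the reversed clock $\tau := 1 - t$ and the process $Y_\tau := X_{1-\tau}$ are notation introduced here for illustration. In these variables the reverse-time SDE becomes an ordinary forward-in-$\tau$ SDE driven by a standard Wiener process $W_\tau$:
$$\begin{array}{rcl}
        \mathrm{d}Y_\tau & = & [-f(Y_\tau, 1-\tau) + g(1-\tau)^2\nabla_x \log{p(Y_\tau, 1-\tau)}]\mathrm{d}\tau + g(1-\tau)\mathrm{d} W_\tau,\\
        Y_0 & \sim & N(0, \sigma^2_{\text{max}} I).
\end{array}$$
An Euler-Maruyama discretization with step $\Delta\tau$ then adds $\Delta\tau$ times the bracketed drift and $\sqrt{\Delta\tau}\, g(1-\tau)$ times a standard Gaussian draw at every step, so that $Y_1 = X_0$ is (approximately) a sample from the target.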

### Computing the score function

Since $p_0 = p_{\text{data}}$ is unknown, we do not have access to the marginals $p_t$ of the forward SDE. From the expression above, having access to such marginals (or an estimate of them) is a prerequisite for solving the reverse SDE.

One simple way of estimating $\nabla_x \log p(x, t)$ is to replace $p_{\text{data}}$ by the empirical distribution, given by Dirac deltas centered at each datapoint, and then use the properties of the SDE to compute the marginals induced by evolving the SDE forward in time with the empirical measure as initial condition.

This provides a very simple formula for the score function, but it leads to the phenomenon of **memorization** which we seek to showcase in this colab.

### Imports

In [None]:
from functools import partial
from typing import Callable, Union
import jax
import jax.numpy as jnp
import jax.random as random
import matplotlib.pyplot as plt

### Samples from $p_{\text{data}}$

For simplicity we consider a simple 2-dimensional distribution: the uniform distribution supported on the unit circle. The helper below returns a configurable number of points, equispaced along the circle, which we use as training samples.

In [None]:
def sample_circle(n_samples: int) -> jax.Array:
  """Generates n_samples equispaced points on the 2d unit circle.

  Args:
    n_samples: Number of samples to be generated.

  Returns:
    A jax.Array of shape (n_samples, 2) containing the samples.
  """
  alphas = jnp.linspace(0, 2*jnp.pi * (1 - 1/n_samples), n_samples)
  xs = jnp.cos(alphas)
  ys = jnp.sin(alphas)
  mf = jnp.stack([xs, ys], axis=1)
  return mf
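
Note that `sample_circle` places its points deterministically at equispaced angles. For comparison, here is a minimal sketch of how i.i.d. draws from the uniform distribution on the circle could be generated instead. The helper name `sample_circle_random` is hypothetical and is not used elsewhere in this colab.

```python
import jax
import jax.numpy as jnp


def sample_circle_random(key: jax.Array, n_samples: int) -> jax.Array:
  """Draws i.i.d. samples from the uniform distribution on the unit circle.

  Hypothetical helper for illustration; not part of the original notebook.
  """
  # Angles are i.i.d. uniform on [0, 2*pi).
  alphas = jax.random.uniform(
      key, (n_samples,), minval=0.0, maxval=2 * jnp.pi
  )
  return jnp.stack([jnp.cos(alphas), jnp.sin(alphas)], axis=1)


random_points = sample_circle_random(jax.random.PRNGKey(0), 8)
print(random_points.shape)
```

With random angles the gaps between neighbouring training points are uneven, whereas the equispaced version keeps the picture symmetric.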

We generate 8 samples and we plot them.

In [None]:
n_samples = 8
mf = sample_circle(n_samples)
plt.figure(figsize=(4, 4))
plt.scatter(mf[:, 0], mf[:, 1])
plt.axis('equal')
plt.title(r'Samples from $p_{\mathrm{data}}$')
plt.show()

### Defining the Noising Process

For simplicity we use specific choices of $f$ and $g$. Namely, we define the function
$$\beta(t) = \beta_\text{min} + t(\beta_\text{max} - \beta_\text{min}),$$
with $\beta_\text{min} = 0.001$ and $\beta_\text{max} = 3$.

We then set
$$f(x, t) := -\frac{1}{2} \beta(t) x \qquad \text{and} \qquad  g(t) := \sqrt{\beta(t)}.$$

The noising process can equivalently be described by a rescaling factor $s(t)$ and a noise level $\sigma(t)$, so that $X_t$ given $X_0$ is Gaussian with mean $s(t) X_0$ and covariance $\sigma^2(t) I$. For the choices above we have
$$s(t) = \exp\left ( - \frac{1}{2}  \int_{0}^t \beta(u) du \right),$$
and
$$\sigma^2(t) = s^2(t) \int_0^t \frac{g^2(\xi)}{s^2(\xi)} d\xi = s^2(t) \int_0^t \beta(\xi)\exp\left (\int_{0}^{\xi} \beta(u) du \right) d\xi = s^2(t) \left[ \exp\left (\int_{0}^{t} \beta(u) du \right) - 1 \right] = 1 - \exp\left ( - \int_{0}^{t} \beta(u) du \right).$$

We define these functions in what follows.

In [None]:
beta_min = 0.001
beta_max = 3

def beta(t: jax.Array) -> jax.Array:
  return beta_min + t*(beta_max - beta_min)

def int_beta(t: jax.Array) -> jax.Array:
  """Integral of beta from 0 to t."""
  return t*beta_min + 0.5 * t**2 * (beta_max - beta_min)

def f(x: jax.Array, t: jax.Array) -> jax.Array:
  return -0.5*beta(t)*x

def g(t: jax.Array) -> jax.Array:
  return jnp.sqrt(beta(t))

def s(t: jax.Array) -> jax.Array:
  return jnp.exp(-0.5 * int_beta(t))

def sigma2(t: jax.Array) -> jax.Array:
  return 1 - jnp.exp(-int_beta(t))
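
As a quick sanity check on these closed forms, the sketch below repeats the schedule definitions so that it runs on its own, and then checks two things: that $s(t)^2 + \sigma^2(t) = 1$ on a grid of times, and that `int_beta` agrees with a trapezoidal quadrature of `beta`. The names `vp_gap` and `quad_gap` are introduced here for illustration only.

```python
import jax.numpy as jnp

beta_min, beta_max = 0.001, 3.0


def beta(t):
  return beta_min + t * (beta_max - beta_min)


def int_beta(t):
  return t * beta_min + 0.5 * t**2 * (beta_max - beta_min)


def s(t):
  return jnp.exp(-0.5 * int_beta(t))


def sigma2(t):
  return 1 - jnp.exp(-int_beta(t))


t = jnp.linspace(0.0, 1.0, 11)

# With these definitions, s(t)^2 + sigma^2(t) should equal 1 for every t.
vp_gap = jnp.max(jnp.abs(s(t) ** 2 + sigma2(t) - 1.0))

# Compare the closed-form integral of beta with a trapezoidal rule.
grid = jnp.linspace(0.0, 1.0, 1001)
vals = beta(grid)
numeric_int = jnp.sum(0.5 * (vals[1:] + vals[:-1]) * (grid[1:] - grid[:-1]))
quad_gap = jnp.abs(numeric_int - int_beta(1.0))

print(vp_gap, quad_gap)
```

Both gaps should be at the level of floating-point round-off, since the trapezoidal rule is exact for a linear integrand.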

### Computing the marginal of the forward SDE

We have seen above that the time-$t$ transition kernel is given by
$$p(x_t | x_0, t) = N(s(t) x_0, \sigma(t)^2 I).$$
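
Because the kernel is Gaussian, a noised sample at time $t$ can be drawn in one shot, without simulating the SDE step by step. The following is a minimal sketch of that idea; the helper `perturb` is hypothetical, and the schedule functions are repeated so that the snippet runs on its own.

```python
import jax
import jax.numpy as jnp

beta_min, beta_max = 0.001, 3.0


def int_beta(t):
  return t * beta_min + 0.5 * t**2 * (beta_max - beta_min)


def s(t):
  return jnp.exp(-0.5 * int_beta(t))


def sigma2(t):
  return 1 - jnp.exp(-int_beta(t))


def perturb(key, x_0, t):
  """Samples x_t ~ N(s(t) x_0, sigma^2(t) I). Hypothetical helper."""
  noise = jax.random.normal(key, shape=x_0.shape)
  return s(t) * x_0 + jnp.sqrt(sigma2(t)) * noise


# Perturb many copies of one point and compare the empirical moments
# with the mean and variance of the transition kernel.
point = jnp.array([1.0, 0.0])
x_0 = jnp.tile(point, (20000, 1))
t = 0.5
x_t = perturb(jax.random.PRNGKey(0), x_0, t)
emp_mean = jnp.mean(x_t, axis=0)
emp_var = jnp.var(x_t, axis=0)

print(emp_mean, s(t) * point)
print(emp_var, sigma2(t))
```

The empirical mean and variance should match $s(t) x_0$ and $\sigma^2(t)$ up to Monte Carlo error.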

We now assume that we have $N$ samples $\{x_i\}_{i=1}^N$ from our target distribution $p_{\text{data}}$.

The empirical measure
$$p^N_0(x) = \frac{1}{N}\sum_{i=1}^{N} \delta_{x_i}(x)$$
is then an approximation to $p_{\text{data}}$. If we start the forward SDE from $p_0 = p^N_0$, we get marginals $p^{N}(x, t)$ given by

$$p^N(x, t) = \frac{1}{N}\sum_{i=1}^{N} N(x; s(t)x_i, \sigma^2(t)I),$$

which is nothing more than a Gaussian mixture with $N$ components, one for each sample $x_i$. Each component of the mixture is centred at $s(t) x_i$ and has covariance $\sigma^2(t) I$.

Therefore we can write down the empirical score function $\nabla_x \log p^{N}(x, t)$ explicitly, although every evaluation of it needs to access the whole training set.
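
For reference, differentiating the logarithm of the Gaussian mixture gives the following closed form, written here as a sketch in which the weights $w_i$ are notation introduced for illustration:
$$\nabla_x \log p^N(x, t) = \sum_{i=1}^{N} w_i(x, t)\, \frac{s(t) x_i - x}{\sigma^2(t)}, \qquad w_i(x, t) = \frac{\exp\left(-\|x - s(t) x_i\|^2 / (2\sigma^2(t))\right)}{\sum_{j=1}^{N} \exp\left(-\|x - s(t) x_j\|^2 / (2\sigma^2(t))\right)}.$$
The weights are non-negative and sum to one, so the score is a weighted average of vectors pointing from $x$ towards the rescaled training points $s(t) x_i$.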

In [None]:
from jax.scipy.special import logsumexp

def build_log_p(x_0: jax.Array) -> Callable[[jax.Array, jax.Array], jax.Array]:
  """Builds the log density log p^N(x, t) using the empirical distribution.

  Args:
    x_0: Samples from the target distribution, with shape (N, dim).

  Returns:
    The log density log p^N(x, t) as described above, up to an additive
    constant that does not depend on x (the Gaussian normalization is
    omitted, which leaves the score unchanged).
  """
  N = x_0.shape[0]
  def log_p(x: jax.Array,  t: jax.Array) -> jax.Array:
    means = x_0 * s(t)
    v = sigma2(t)
    potentials = jnp.sum(-(x - means)**2 / (2 * v), axis=1)
    # logsumexp with weights b=1/N is equivalent to
    #   jnp.log(1/N * jnp.sum(jnp.exp(potentials)))
    # but is numerically more stable.
    return logsumexp(potentials, axis=0, b=1/N)
  return log_p

log_p = build_log_p(mf)

nabla_log_hat_pt = jax.jit(
    jax.vmap(jax.grad(log_p), in_axes=(0, 0), out_axes=(0))
    )
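
The automatic-differentiation score can be cross-checked against the analytic gradient of a Gaussian mixture, which is a softmax-weighted average of the vectors pointing to the rescaled training points. The sketch below is self-contained, so it repeats the schedule and the eight circle points; `closed_form_score` and `x_train` are hypothetical names used only here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

beta_min, beta_max = 0.001, 3.0


def int_beta(t):
  return t * beta_min + 0.5 * t**2 * (beta_max - beta_min)


def s(t):
  return jnp.exp(-0.5 * int_beta(t))


def sigma2(t):
  return 1 - jnp.exp(-int_beta(t))


# Eight equispaced training points on the unit circle.
alphas = jnp.linspace(0, 2 * jnp.pi * (1 - 1 / 8), 8)
x_train = jnp.stack([jnp.cos(alphas), jnp.sin(alphas)], axis=1)


def log_p(x, t):
  """Log of the mixture density, up to a constant independent of x."""
  potentials = jnp.sum(-(x - s(t) * x_train) ** 2 / (2 * sigma2(t)), axis=1)
  return logsumexp(potentials) - jnp.log(x_train.shape[0])


def closed_form_score(x, t):
  """Analytic gradient of log_p with respect to x."""
  means = s(t) * x_train
  potentials = jnp.sum(-(x - means) ** 2 / (2 * sigma2(t)), axis=1)
  weights = jax.nn.softmax(potentials)  # One weight per training point.
  return weights @ (means - x) / sigma2(t)


x = jnp.array([0.3, -0.7])
t = 0.2
autodiff_score = jax.grad(log_p)(x, t)
analytic_score = closed_form_score(x, t)
print(autodiff_score, analytic_score)
```

The two vectors should agree up to floating-point error; the autodiff version is convenient, while the analytic one makes the dependence on every training point explicit.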

Following our intuition from Langevin dynamics, the score function should point towards the support of the distribution. We visualize this in what follows.

In [None]:
def plot_score(
    score: Callable[[jax.Array, jax.Array], jax.Array],
    t: jax.Array,
    area_min: float=-1.,
    area_max: float=1.,
    data_samples: jax.Array | None = None
) -> None:
  """
  Plots the score function and optionally overlays data samples.

  Args:
    score: A callable function that takes a position and a time, and returns
      the score evaluated at that position.
    t: The time value at which to evaluate the score function.
    area_min: The minimum value for the x and y axes of the plot.
    area_max: The maximum value for the x and y axes of the plot.
    data_samples: Optional jax.Array of data samples to plot.
  """
  @partial(jax.jit, static_argnums=[0,])
  def _helper(
    score: Callable[[jax.Array, jax.Array], jax.Array],
    t: jax.Array,
    area_min: float,
    area_max: float,
  ) -> tuple[jax.Array, jax.Array]:
    x = jnp.linspace(area_min, area_max, 32)
    x, y = jnp.meshgrid(x, x)
    grid = jnp.stack([x.flatten(), y.flatten()], axis=1)
    t = jnp.ones((grid.shape[0], 1)) * t
    scores = score(grid, t)
    return grid, scores

  grid, scores = _helper(score, t, area_min, area_max)

  plt.figure(figsize=(6, 6))
  plt.quiver(
    grid[:, 0],
    grid[:, 1],
    scores[:, 0],
    scores[:, 1],
    label=r"$\nabla_x \log p^N$",
  )
  plt.axis("equal")

  if data_samples is not None:  # Overlay the training samples, if provided.
    plt.scatter(
      data_samples[:, 0],
      data_samples[:, 1],
      label=r"Samples from $p_{\mathrm{data}}$",
    )

  plt.legend()
  plt.show()


plot_score(nabla_log_hat_pt, 0.005, -1.5, 1.5, data_samples=mf)

Here we can observe where the main issue of this approach will arise: the flow points towards the support of the empirical distribution, and **not** the underlying one.
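
One way to see why the flow targets individual training points is to look at how much each mixture component contributes to the score at a given location. The sketch below computes these per-point weights (a softmax over the Gaussian potentials) at a small and at a large time. The function `mixture_weights` and the query point are assumptions made for illustration; the schedule and the eight circle points are repeated so that the snippet runs on its own.

```python
import jax
import jax.numpy as jnp

beta_min, beta_max = 0.001, 3.0


def int_beta(t):
  return t * beta_min + 0.5 * t**2 * (beta_max - beta_min)


def s(t):
  return jnp.exp(-0.5 * int_beta(t))


def sigma2(t):
  return 1 - jnp.exp(-int_beta(t))


# Eight equispaced training points on the unit circle.
alphas = jnp.linspace(0, 2 * jnp.pi * (1 - 1 / 8), 8)
x_train = jnp.stack([jnp.cos(alphas), jnp.sin(alphas)], axis=1)


def mixture_weights(x, t):
  """Relative contribution of each training point to the score at (x, t)."""
  means = s(t) * x_train
  potentials = jnp.sum(-(x - means) ** 2 / (2 * sigma2(t)), axis=1)
  return jax.nn.softmax(potentials)


# A point on the circle that is not a training point: angle 0.2 lies
# between the training angles 0 and pi/4, closer to angle 0.
x = jnp.array([jnp.cos(0.2), jnp.sin(0.2)])
w_small_t = mixture_weights(x, 0.01)
w_large_t = mixture_weights(x, 1.0)

print(jnp.max(w_small_t), jnp.argmax(w_small_t))
print(jnp.max(w_large_t))
```

At small times essentially all the weight sits on the nearest training point, so the score pulls towards that point rather than along the circle; at large times the weights are spread over many points.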

To further visualize this problem, we consider the sampler that we built in the adjacent [notebook](https://github.com/google-research/swirl-dynamics/blob/main/swirl_dynamics/projects/probabilistic_diffusion/colabs/tutorial/diffusion_tutorial.ipynb) by solving the SDE backwards in time.

In [None]:
def sde_solver_backwards(
    key: jax.Array,
    grad_log: Callable[[jax.Array, jax.Array], jax.Array],
    g: Callable[[jax.Array], jax.Array],
    f: Callable[[jax.Array, jax.Array], jax.Array],
    dim: int,
    n_samples: int,
    num_time_steps: int = 100,
) -> tuple[jax.Array, jax.Array]:
    """Euler-Maruyama solver for the backward SDE.

    Args:
        key: Seed for the random number generator.
        grad_log: Score function, taking positions and times.
        g: Diffusion coefficient of the forward SDE.
        f: Drift coefficient of the forward SDE.
        dim: Dimension of the problem.
        n_samples: Number of samples.
        num_time_steps: Number of time steps.

    Returns:
        A tuple containing the initial condition (x_1) and the sampled values.
    """

    ts = jnp.linspace(1 / (num_time_steps - 1), 1, num_time_steps)
    delta_t = ts[1:] - ts[:-1]

    def time_step(
          carry: tuple[jax.Array, jax.Array],
          params_time: tuple[jax.Array, jax.Array]
        )-> tuple[tuple[jax.Array, jax.Array], None]:
        """Performs one step of the Euler-Maruyama."""
        key, x = carry
        t, dt = params_time
        key, subkey = random.split(key)

        # Euler-Maruyama step
        diff = g(1 - t)
        t_broadcasted = jnp.ones((x.shape[0], 1)) * t
        drift = -f(x, 1 - t_broadcasted) + grad_log(x, 1 - t_broadcasted) * diff**2
        noise = random.normal(subkey, shape=x.shape)
        x = x + dt * drift + jnp.sqrt(dt) * diff * noise
        return (key, x), None  # We don't need to collect intermediate x

    key, subkey = random.split(key)
    sigma2_1 = sigma2(1.0)
    x_1 = jnp.sqrt(sigma2_1) * random.normal(subkey, shape=(n_samples, dim))

    carry = (key, x_1)
    (_, samples), _ = jax.lax.scan(time_step, carry, jnp.stack([ts[:-1], delta_t], axis=1))
    return x_1, samples

Now we can use this solver to draw samples with the empirical approximation of the score function.

In [None]:
rng, step_rng = random.split(jax.random.PRNGKey(0))
x_1, gen_samples = sde_solver_backwards(step_rng, nabla_log_hat_pt, g, f, 2, 5000)

In [None]:
def plot_heatmap(positions: jax.Array,
                 area_min: float = -2.,
                 area_max: float = 2.) -> None:
  r"""Builds and plots a heatmap of the target distribution.

  Args:
      positions: Locations of all particles in $\mathbb{R}^2$, array (N, 2).
      area_min: Lowest x and y coordinates.
      area_max: Highest x and y coordinates.

  Returns:
      None, but it will plot a heatmap of all particles in the area
      [area_min, area_max] x [area_min, area_max].
  """
  if area_min >= area_max:
    raise ValueError("area_min should be strictly lower than area_max.")

  @jax.jit
  def produce_heatmap(
    positions: jax.Array, area_min: float, area_max: float
  )-> jax.Array:
    """Generates the heatmap data from particle positions."""
    # Define the grid for the heatmap.
    grid = jnp.linspace(area_min, area_max, 512)
    x, y = jnp.meshgrid(grid, grid)

    # Vectorized computation of distances.
    x_pos = positions[:, 0]
    y_pos = positions[:, 1]
    dist = (x - x_pos[:, None, None])**2 + (y - y_pos[:, None, None])**2

    # Vectorized computation of the heatmap contribution.
    heatmap_values = jnp.exp(-350 * dist)

    # Sum the contributions from all particles.
    return jnp.sum(heatmap_values, axis=0)

  # Generate the heatmap data.
  heatmap_data = produce_heatmap(positions, area_min, area_max)

  # Plot the heatmap. Row 0 of heatmap_data corresponds to y = area_min, so
  # the origin is placed in the lower-left corner to match the extent.
  extent = [area_min, area_max, area_min, area_max]
  plt.imshow(
    heatmap_data,
    cmap="coolwarm",
    interpolation='nearest',
    extent=extent,
    origin="lower",
    )

  # Add labels and title.
  plt.xlabel("X Coordinate")
  plt.ylabel("Y Coordinate")
  plt.title("Particle Heatmap")

  # Show the plot
  plt.show()

First we look at the samples we used as initial conditions for the reverse SDE (equivalently, terminal conditions for the forward one). We can observe that they follow a centered Gaussian.

In [None]:
plot_heatmap(x_1, -2, 3)

However, the samples obtained by solving the reverse SDE concentrate around the training points. As a result, the sampler returns points from the training set, thus **memorizing** it instead of learning the underlying distribution.

In [None]:
plot_heatmap(gen_samples, -3, 3)
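
Beyond the visual check, memorization could be quantified by comparing, for each generated sample, its distance to the nearest training point with its distance to the circle itself. The sketch below is one possible diagnostic; the function `memorization_diagnostics` is hypothetical, and the two sample sets are synthetic stand-ins (jittered training points and uniform draws on the circle), not outputs of the sampler above.

```python
import jax
import jax.numpy as jnp


def memorization_diagnostics(samples, train):
  """Returns distances to the nearest training point and to the unit circle.

  Hypothetical diagnostic for illustration; assumes the true support is
  the unit circle.
  """
  # Pairwise distances, shape (n_samples, n_train).
  pairwise = jnp.linalg.norm(samples[:, None, :] - train[None, :, :], axis=-1)
  nearest_train = jnp.min(pairwise, axis=1)
  # Distance from each sample to the unit circle.
  to_circle = jnp.abs(jnp.linalg.norm(samples, axis=-1) - 1.0)
  return nearest_train, to_circle


# Eight equispaced training points on the unit circle.
alphas = jnp.linspace(0, 2 * jnp.pi * (1 - 1 / 8), 8)
train = jnp.stack([jnp.cos(alphas), jnp.sin(alphas)], axis=1)

key_idx, key_noise, key_angle = jax.random.split(jax.random.PRNGKey(0), 3)

# Stand-in for a memorizing sampler: training points plus small jitter.
idx = jax.random.randint(key_idx, (1000,), 0, 8)
memorized_like = train[idx] + 0.01 * jax.random.normal(key_noise, (1000, 2))

# Stand-in for a generalizing sampler: uniform draws on the circle.
angles = jax.random.uniform(key_angle, (1000,), minval=0.0, maxval=2 * jnp.pi)
circle_like = jnp.stack([jnp.cos(angles), jnp.sin(angles)], axis=1)

nn_mem, circ_mem = memorization_diagnostics(memorized_like, train)
nn_circ, circ_circ = memorization_diagnostics(circle_like, train)
print(jnp.mean(nn_mem), jnp.mean(nn_circ))
```

In this colab the same function could be called as `memorization_diagnostics(gen_samples, mf)`: a nearest-training-point distance that is much smaller than the typical spacing between training points would be consistent with the memorization seen in the heatmap.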