In [None]:
%pip install -q plotly
%pip install -q numpy
%pip install -q scipy
%pip install -q nbformat
%pip install -q pandas

import numpy as np
import pandas as pd
import plotly.express as px
from scipy import stats

from typing import Callable

In [2]:
def batch_norm_pdf(
    x: np.ndarray, means: np.ndarray, stddev: float) -> np.ndarray:
  """Batched multivariate normal pdf with a fixed, isotropic covariance.

  Pure numpy. Entry (i, j) equals
  `stats.multivariate_normal.pdf(x[i], means[j], stddev)`.

  Args:
    x: Data points, B x D.
    means: Means, C x D.
    stddev: Despite the name, this is a variance: the covariance matrix is
      np.eye(D) * stddev. Must be positive.

  Returns:
    A matrix B x C such that m_i_j is the pdf of data point i under normal
    distribution j. The result is always 2-D, also when B == 1 or C == 0.

  Raises:
    ValueError: If x and means have different dimensions, or if stddev is
      not positive.
  """
  assert len(x.shape) == len(means.shape) == 2
  dim = x.shape[1]
  if means.shape[1] != dim:
    raise ValueError(
      f'x has dimension {dim} but means have dimension {means.shape[1]}.')
  # `not >` also rejects NaN.
  if not stddev > 0:
    raise ValueError(f'stddev must be positive, got {stddev}.')

  # Pairwise differences B x C x D, reduced to squared distances B x C.
  diffs = x[:, np.newaxis, :] - means[np.newaxis, :, :]
  sq_dists = np.sum(diffs * diffs, axis=-1)
  # log N(x; m, s * I) = -0.5 * (D * log(2 * pi * s) + ||x - m||^2 / s)
  log_pdf = -0.5 * (dim * np.log(2.0 * np.pi * stddev) + sq_dists / stddev)
  return np.exp(log_pdf)


class DiffusionEvolution:
  """A run of diffusion evolution. (https://arxiv.org/abs/2410.02543).

  The run starts at time step t = `steps` and performs one denoising step at
  a time until t reaches 0. Every population is kept in `step_candidates`,
  with the matching fitness arrays in `fitness`.
  """
  def __init__(
      self,
      candidates: np.ndarray,
      fitness_fun: Callable[[np.ndarray], np.ndarray],
      steps: int,
      sigma_scale: float):
    """Create a new diffusion evolution instance.

    Args:
      candidates: An array of N > 1 initial candidates, N x D.
      fitness_fun: Maps an N x D array of candidates to an array of N
        positive fitness values.
      steps: Total number of steps.
      sigma_scale: The scale of the noise term relative to maximum
        sqrt(1 - alpha_{t-1}) in each step. Must be in (0, 1) (exclusive).
    """
    assert len(candidates.shape) == 2
    assert len(candidates) > 1
    assert 0 < sigma_scale < 1.0
    self.fitness_fun = fitness_fun
    self.steps = steps
    self.sigma_scale = sigma_scale
    # Population history; the last entry is the current population.
    self.step_candidates = [candidates]
    # Fitness of each population in `step_candidates`, index-aligned.
    self.fitness = [fitness_fun(candidates)]

  @property
  def t(self) -> int:
    """Current time step: `steps` before the first step, 0 when done."""
    # Each step appends one population to the initial one.
    return self.steps - len(self.step_candidates) + 1

  def alpha(self, t: int) -> float:
    """Diffusion schedule, linear from alpha(0) = 1 to alpha(steps) = 0."""
    return 1.0 - (t / self.steps)

  def sigma(self, t: int) -> float:
    """Noise schedule: sigma_scale * sqrt(1 - alpha(t - 1))."""
    return self.sigma_scale * np.sqrt(1 - self.alpha(t - 1))

  @staticmethod
  def _log_kernel(
      x: np.ndarray, scale: float, variance: float) -> np.ndarray:
    """Log N(x_i; scale * x_j, variance * I) for all pairs (i, j), N x N.

    The Gaussian normalization constant is omitted. It is identical for every
    entry and cancels when the weights are normalized.
    """
    diffs = x[:, np.newaxis, :] - scale * x[np.newaxis, :, :]
    return -0.5 * np.sum(diffs * diffs, axis=-1) / variance

  def _step(self) -> np.ndarray:
    """Step the algorithm and return new candidates.

    If a candidate's fitness-weighted row sum is 0 (e.g. zero fitness for
    every candidate with non-negligible kernel weight), its estimate is NaN.
    """
    candidates = self.step_candidates[-1]
    t = self.t
    alpha_t = self.alpha(t)
    alpha_prev = self.alpha(t - 1)
    sigma_t = self.sigma(t)

    # Kernel N(x_i; sqrt(alpha_t) * x_j, (1 - alpha_t) I), N x N, in log
    # space. Subtracting each row's maximum before exp cancels in the
    # per-row normalization below. It also stops far-away candidates from
    # underflowing to an all-zero row. With plain pdf values that happens
    # once ||x_i - sqrt(alpha_t) x_j||^2 / (2 (1 - alpha_t)) exceeds ~745,
    # e.g. for ||x|| > ~38 at the first step, and the row turns into NaN.
    log_kernel = self._log_kernel(candidates, np.sqrt(alpha_t), 1 - alpha_t)
    log_kernel -= np.max(log_kernel, axis=-1, keepdims=True)
    weights = self.fitness[-1][np.newaxis, :] * np.exp(log_kernel)

    # New N x D estimate for the origin point of each candidate: a
    # fitness- and kernel-weighted average of the population.
    weights_sum = np.sum(weights, axis=-1, keepdims=True)
    x0s = (weights / weights_sum) @ candidates

    # Move towards the estimated origin, keep part of the current direction
    # (eps_hat), and add fresh noise with stddev sigma_t.
    x0_term = np.sqrt(alpha_prev) * x0s
    eps_hat = (candidates - np.sqrt(alpha_t) * x0s) / np.sqrt(1 - alpha_t)
    offset_term = np.sqrt(1 - alpha_prev - sigma_t ** 2) * eps_hat
    noise_term = np.random.normal(0, sigma_t, size=candidates.shape)
    new_candidates = x0_term + offset_term + noise_term
    self.step_candidates.append(new_candidates)
    self.fitness.append(self.fitness_fun(new_candidates))
    return new_candidates

  def diffuse(self) -> np.ndarray:
    """Run the algorithm until t = 0 and return the final population."""
    while self.t > 0:
      self._step()
    return self.step_candidates[-1]

  def run_data(self) -> pd.DataFrame:
    """Diffuse and return the run data as a dataframe.

    Returns:
      One row per (frame, candidate) pair, ordered by frame and then by
      candidate. Columns are 'frame', 'candidate', 'x' and 'y' (the first two
      candidate coordinates only), and 'fitness'. Frame 0 is the initial
      population.
    """
    self.diffuse()
    candidates = np.array(self.step_candidates)  # frames x N x D
    num_frames, num_candidates = candidates.shape[:2]
    return pd.DataFrame({
      'frame': np.repeat(np.arange(num_frames), num_candidates),
      'candidate': np.tile(np.arange(num_candidates), num_frames),
      'x': candidates[:, :, 0].ravel(),
      'y': candidates[:, :, 1].ravel(),
      'fitness': np.asarray(self.fitness).ravel(),
    })


In [None]:
# 1000 initial 2-D candidates, drawn uniformly from the square [-5, 5) x [-5, 5).
candidates = np.random.uniform(low=-5, high=5, size=(1000, 2))
def four_gaussians(x: np.ndarray) -> np.ndarray:
  """Equal-weight mixture of four isotropic 2-D gaussians on a grid.

  The components sit at (+-2.5, +-2.5) with covariances 0.5, 1, 2 and 4
  times the identity, so the peaks differ in height and width.

  Args:
    x: Points to evaluate; the last axis holds the two coordinates.

  Returns:
    The mixture density at each point.
  """
  components = (
    ([-2.5, 2.5], 0.5),
    ([2.5, 2.5], 1.0),
    ([-2.5, -2.5], 2.0),
    ([2.5, -2.5], 4.0),
  )
  density = sum(
    stats.multivariate_normal.pdf(x, mean, cov) for mean, cov in components)
  return density / 4.0

# The run explored below: 50 steps, noise at 25% of its per-step maximum.
diff_evo = DiffusionEvolution(
  candidates=candidates,
  fitness_fun=four_gaussians,
  steps=50,
  sigma_scale=0.25,
)

def plot(
    diff_evo: DiffusionEvolution,
    *,
    bounds: tuple[float, float] = (-5.0, 5.0),
    resolution: int = 100):
  """Run `diff_evo` to completion and plot the run.

  Shows two figures:
    - an animated scatter of the population per frame, over a heatmap of the
      fitness landscape;
    - the mean fitness per frame.
  Only the first two candidate coordinates are plotted.

  Args:
    diff_evo: The run to plot. Its `run_data()` is called, which performs
      any remaining diffusion steps.
    bounds: (low, high) extent of both axes and of the heatmap grid.
    resolution: Number of heatmap grid points per axis.
  """
  df = diff_evo.run_data()
  # Ensure markers have a minimum size so low-fitness candidates stay visible.
  df['size'] = df['fitness'].clip(lower=0.01)

  # Fitness landscape on a resolution x resolution grid. fitness_fun takes an
  # N x D array, so evaluate the flattened grid and reshape back.
  low, high = bounds
  x_vals = np.linspace(low, high, resolution)
  y_vals = np.linspace(low, high, resolution)
  X, Y = np.meshgrid(x_vals, y_vals)
  grid_points = np.column_stack([X.ravel(), Y.ravel()])
  Z = np.asarray(diff_evo.fitness_fun(grid_points)).reshape(X.shape)

  fig = px.scatter(
    df, x='x', y='y', size='size',
    size_max=15,
    color='fitness', animation_frame='frame',
    range_x=[low, high], range_y=[low, high])
  fig.add_heatmap(
    x=x_vals, y=y_vals, z=Z, colorscale='Viridis', showscale=False)
  fig.update_layout(
    height=800,
    width=800,
    title='Diffusion Evolution',
  )
  fig.show()

  # One point per frame; plotting the per-candidate rows directly would draw
  # N duplicate points per frame.
  fitness_per_frame = (
    df.groupby('frame')['fitness'].mean()
    .rename('mean_fitness')
    .reset_index())
  fitness_fig = px.line(
    fitness_per_frame, x='frame', y='mean_fitness',
    title='Mean fitness over time')
  fitness_fig.update_layout(
    height=400,
    width=800)
  fitness_fig.show()

# Run the diffusion to completion, then show the population animation and
# the fitness curve.
plot(diff_evo=diff_evo)
