In [None]:
import plotly

In [None]:
import jax
import numpy as np
import plotly.graph_objects as go
import scipy
import scipy.stats
from jax import numpy as jnp
from plotly.subplots import make_subplots


def variance_preserving(logsnr, t):
  """Return ``(alpha, sigma)`` for the variance-preserving parameterization.

  Args:
    logsnr: log signal-to-noise ratio (scalar or array).
    t: time; unused here, accepted so every parameterization shares the
      ``(logsnr, t)`` signature.

  Returns:
    ``((1 + exp(-logsnr)) ** -0.5, (1 + exp(logsnr)) ** -0.5)``, i.e. the
    square roots of ``sigmoid(logsnr)`` and ``sigmoid(-logsnr)``, so that
    ``alpha ** 2 + sigma ** 2 == 1``.
  """
  return (1 + jnp.exp(-logsnr)) ** -0.5, (1 + jnp.exp(logsnr)) ** -0.5


def subvariance(lognsr, t):
  """Return ``(alpha, sigma) = (1 - t + 1e-9, t)``: coefficients linear in ``t``.

  Args:
    lognsr: log signal-to-noise ratio; unused. The misspelled name is kept
      as-is so existing keyword callers keep working.
    t: time.

  Returns:
    ``(1 - t + 1e-9, t)``. The ``1e-9`` offset keeps ``alpha`` non-zero at
    ``t == 1``.
  """
  return 1. - t + 1e-9, t

def interp_alpha_sigma(interp):
  """Build an ``(alpha, sigma)`` function blending the two parameterizations.

  Args:
    interp: weight given to ``variance_preserving``; ``subvariance`` receives
      ``1 - interp``. ``interp=0`` therefore reproduces ``subvariance`` and
      ``interp=1`` reproduces ``variance_preserving``.

  Returns:
    A function ``alpha_sigma(logsnr, t)`` returning the blended
    ``(alpha, sigma)`` pair.
  """
  vp_weight = interp
  sv_weight = 1. - interp

  def alpha_sigma(logsnr, t):
    vp_coeffs = variance_preserving(logsnr, t)
    sv_coeffs = subvariance(logsnr, t)
    # Blend alpha with alpha and sigma with sigma, using the same weights.
    alpha, sigma = (vp_weight * vp + sv_weight * sv
                    for vp, sv in zip(vp_coeffs, sv_coeffs))
    return alpha, sigma

  return alpha_sigma

def _logsnr_schedule(t):
  """Log-SNR assigned to time ``t``: large and positive near t=0, negative near t=1."""
  # The 1e-6 offsets keep both logarithms finite at the endpoints t=0 and t=1.
  return -2. * (jnp.log(t + 1e-6) - jnp.log(1. + 1e-6 - t))


def _advance(latent, alpha_t, sigma_t, alpha_out, sigma_out, beta2_t, x_mean, dt=None):
  """Move ``latent`` one sampler update forward.

  With ``dt=None`` (DDIM) ``alpha_out``/``sigma_out`` are the coefficients at
  the target time and the result is ``alpha_out * pred_x + sigma_out * pred_eps``.
  Otherwise (Euler) they are time-derivatives of the coefficients and the
  result is ``latent + (alpha_out * pred_x + sigma_out * pred_eps) * dt``.
  """
  pred_x = beta2_t * (latent / alpha_t - x_mean) + x_mean
  pred_eps = (latent - alpha_t * pred_x) / sigma_t
  update = alpha_out * pred_x + sigma_out * pred_eps
  if dt is None:
    return update
  return latent + update * dt


def _sample_paths(alpha_sigma, start_points, num_points, x_mean, x_var,
                  euler=False, track_moments=True):
  """Integrate ``start_points`` from t=1 down to t=0 on a uniform time grid.

  Returns ``(tts, zs, z_mean, z_var)``: the descending time grid, the stacked
  trajectories (one row per grid point), and the mean/variance obtained by
  pushing the initial values ``0.`` and ``1.`` through the same updates
  (left at those initial values when ``track_moments`` is False).
  """
  tts = jnp.linspace(0., 1., num=num_points)[::-1]
  z = start_points
  zs = [start_points]
  z_mean = 0.
  z_var = 1.

  for t, s in zip(tts[:-1], tts[1:]):
    logsnr_t = _logsnr_schedule(t)
    if euler:
      # jvp returns the coefficients at t together with their derivatives in t.
      (alpha_t, sigma_t), (alpha_out, sigma_out) = jax.jvp(
          lambda u: alpha_sigma(_logsnr_schedule(u), u), (t,), (jnp.ones_like(t),))
      dt = s - t
    else:
      alpha_t, sigma_t = alpha_sigma(logsnr_t, t)
      alpha_out, sigma_out = alpha_sigma(_logsnr_schedule(s), s)
      dt = None
    beta2_t = (1 + jnp.exp(-logsnr_t - jnp.log(x_var))) ** -1

    z = _advance(z, alpha_t, sigma_t, alpha_out, sigma_out, beta2_t, x_mean, dt)
    zs.append(z)

    if track_moments:
      z_mean = _advance(z_mean, alpha_t, sigma_t, alpha_out, sigma_out, beta2_t, x_mean, dt)
      # The update is affine in z; `gain` is its slope, so the variance is
      # multiplied by gain ** 2.
      gain = alpha_out * beta2_t / alpha_t + sigma_out / sigma_t * (1. - beta2_t)
      if euler:
        gain = 1. + gain * dt
      z_var = gain ** 2. * z_var

  return tts, jnp.stack(zs), z_mean, z_var


def _add_path_traces(fig, row, tts, zs, name, legendgroup, line, **scatter_kwargs):
  """Draw each column of ``zs`` as a line in the middle subplot of ``row``."""
  for i in range(zs.shape[1]):
    # Only the first line of the first row gets a legend entry for its group.
    fig.add_trace(go.Scatter(x=1. - tts, y=zs[:, i], mode='lines', line=line,
                             showlegend=i == 0 and row == 1, name=name,
                             legendgroup=legendgroup, hoverinfo='skip',
                             **scatter_kwargs),
                  row=row, col=2)


def make_plot(x_mean=0., x_var=0.001, alpha_sigmas=(variance_preserving,), title=""):
  """Plot sampling trajectories for one or more (alpha, sigma) parameterizations.

  One subplot row is drawn per entry of ``alpha_sigmas``. Each row has three
  columns: the standard-normal noise density (left), the trajectories plotted
  against ``1 - t`` (middle), and the data density overlaid with the Gaussian
  densities reached by the coarse DDIM and Euler samplers (right).

  Args:
    x_mean: mean of the Gaussian data distribution.
    x_var: variance of the Gaussian data distribution.
    alpha_sigmas: non-empty sequence of functions ``(logsnr, t) -> (alpha, sigma)``.
    title: figure title.

  Returns:
    The assembled plotly figure.

  Raises:
    ValueError: if ``alpha_sigmas`` is empty.

  Side effects:
    Prints the final Euler ``(z_mean, z_var)`` once per row.
  """
  # NOTE(review): int_steps is the number of time-grid points, so the coarse
  # samplers take int_steps - 1 = 5 updates although the legend entries say
  # "6-step" -- confirm which one is intended.
  int_steps = 6
  n_rows = len(alpha_sigmas)
  if n_rows == 0:
    raise ValueError("alpha_sigmas must contain at least one parameterization")

  fig = make_subplots(rows=n_rows, cols=3,
                      specs=[[{"type": "scatter"}, {"type": "scatter"}, {"type": "scatter"}],] * n_rows,
                      subplot_titles=['noise', '', 'data'],
                      column_widths=[0.15, 0.7, 0.15],
                      shared_xaxes=False, shared_yaxes=True,
                      horizontal_spacing=0.02, vertical_spacing=0.05)

  # Marginal densities for the outer columns. The noise density is scaled to
  # peak at 1; every data-side density is divided by 1.5x the data peak.
  ys = jnp.linspace(-3., 3., 100)
  data_std = x_var ** 0.5
  eps_dist = scipy.stats.norm.pdf(ys) / scipy.stats.norm.pdf(0.)
  x_norm = scipy.stats.norm.pdf(0., loc=0., scale=data_std) * 1.5
  # Centre the data density on x_mean (it was previously always drawn at 0).
  x_dist = scipy.stats.norm.pdf(ys, loc=x_mean, scale=data_std) / x_norm

  thin_black = dict(color='black', width=1.)
  for row in range(1, 1 + n_rows):
    fig.add_trace(go.Scatter(x=eps_dist, y=ys, line=thin_black, mode='lines', showlegend=False, hoverinfo='skip'), row=row, col=1)
    fig.add_trace(go.Scatter(x=x_dist, y=ys, line=thin_black, mode='lines', showlegend=False, hoverinfo='skip'), row=row, col=3)

  start_points = jnp.linspace(-2., 2., num=6)

  # (use Euler?, legend name, legend group, colour) for the coarse samplers.
  coarse_samplers = (
      (False, '6-step DDIM', 'ddim', '#1f77b4'),
      (True, '6-step Euler', 'euler', '#FF7F0e'),
  )

  for row, alpha_sigma in enumerate(alpha_sigmas, start=1):
    # Reference trajectories: the DDIM update on a fine 256-point grid.
    tts, zs, _, _ = _sample_paths(alpha_sigma, start_points, 256, x_mean, x_var,
                                  track_moments=False)
    _add_path_traces(fig, row, tts, zs, name="Ground truth",
                     legendgroup='ground-truth', line=thin_black, opacity=0.5)

    for euler, name, group, color in coarse_samplers:
      tts, zs, z_mean, z_var = _sample_paths(alpha_sigma, start_points, int_steps,
                                             x_mean, x_var, euler=euler)
      if euler:
        # Diagnostic output kept from the original notebook.
        print(z_mean, z_var)
      thick = dict(color=color, width=2.)
      z0_dist = scipy.stats.norm.pdf(ys, loc=z_mean, scale=z_var ** 0.5) / x_norm
      fig.add_trace(go.Scatter(x=z0_dist, y=ys, mode='lines', line=thick, showlegend=False, legendgroup=group, hoverinfo='skip'), row=row, col=3)
      _add_path_traces(fig, row, tts, zs, name=name, legendgroup=group, line=thick)

  # Axis styling, generated per row. Axes are numbered row by row, three per
  # row, so row r (0-based) owns xaxis{3r+1}..xaxis{3r+3}; the y range is set
  # on the first y-axis of each row.
  axis_layout = {}
  for r in range(n_rows):
    first = 3 * r + 1
    axis_layout[f'yaxis{first}'] = dict(range=[-2.2, 2.2], fixedrange=True)
    axis_layout[f'xaxis{first}'] = dict(visible=False, range=[0., 1.])
    axis_layout[f'xaxis{first + 1}'] = dict(visible=False)
    axis_layout[f'xaxis{first + 2}'] = dict(visible=False, range=[0., 1.])

  fig.update_layout(
      title=title,
      template='simple_white',
      showlegend=True,
      **axis_layout,
  )
  return fig
title = 'Flow Matching and Variance Preserving sampling paths'

# Base figure shown before the slider is moved.
fig = make_plot(x_mean=0., x_var=0.25, alpha_sigmas=[subvariance], title=title)

# Interpolation weights passed to interp_alpha_sigma: 0 reproduces
# `subvariance` (labelled FM below), 1 reproduces `variance_preserving` (VP).
interp_weights = np.linspace(0., 1., 8)
plots = [
    make_plot(x_mean=0., x_var=0.25, alpha_sigmas=[interp_alpha_sigma(weight)], title=title)
    for weight in interp_weights
]

# Frame names are shared by the slider steps and the frames so they match.
frame_names = [str(index) for index in range(len(interp_weights))]

sliders_dict = {
    "active": 0,
    "yanchor": "top",
    "xanchor": "left",
    "currentvalue": {
        "font": {"size": 20},
        "prefix": "Parameterization: ",
        "visible": True,
        "xanchor": "right"
    },
    "transition": {"duration": 300, "easing": "cubic-in-out"},
    "pad": {"b": 10, "t": 50},
    "len": 0.9,
    "x": 0.1,
    "y": 0,
    "steps": [
        {
            "args": [
                [frame_name],
                {"frame": {"duration": 30, "redraw": False},
                 "mode": "immediate",
                 "transition": {"duration": 30}}
            ],
            "label": '',
            "method": "animate",
        }
        for frame_name in frame_names
    ],
}
# Only the two ends of the slider carry a label.
sliders_dict['steps'][0]['label'] = 'FM'
sliders_dict['steps'][-1]['label'] = 'VP'

# Each frame replaces every trace of the base figure with the trace at the
# same index in the corresponding plot.
fig.update(frames=[
    go.Frame(name=frame_name,
             data=plot.data,
             traces=list(range(len(fig.data))))
    for frame_name, plot in zip(frame_names, plots)])

fig.update_layout(sliders=[sliders_dict])

fig.show()
print("Straight to a point is not the same as straight between distributions")


fig.write_html('../assets/html/2025-04-28-distill-example/interactive_alpha_sigma.html', include_plotlyjs="cdn")

0.0 0.13634641
0.0 0.13634641
0.0 0.14877477
0.0 0.16092104
0.0 0.17275429
0.0 0.18426582
0.0 0.19545211
0.0 0.2063134
0.0 0.21685322


Straight to a point is not the same as straight between distributions


In [None]:



title = 'Variance Preserving vs Flow Matching paths for varying data distributions'


# Data variances selectable with the slider.
x_vars = np.array([0.01, 0.05, 0.1, 0.5, 1.])
# Base figure shown before the slider is moved (same settings as the first frame).
fig = make_plot(x_mean=0., x_var=x_vars[0], alpha_sigmas=[variance_preserving, subvariance], title=title)

plots = [
    make_plot(x_mean=0., x_var=x_var, alpha_sigmas=[variance_preserving, subvariance], title=title)
    for x_var in x_vars
]

# Frame names are shared by the slider steps and the frames so they match.
frame_names = [str(index) for index in range(len(x_vars))]

sliders_dict = {
    "active": 0,
    "yanchor": "top",
    "xanchor": "left",
    "currentvalue": {
        "font": {"size": 20},
        "prefix": "data variance: ",
        "visible": True,
        "xanchor": "right"
    },
    "transition": {"duration": 300, "easing": "cubic-in-out"},
    "pad": {"b": 10, "t": 50},
    "len": 0.9,
    "x": 0.1,
    "y": 0,
    "steps": [
        {
            "args": [
                [frame_name],
                {"frame": {"duration": 30, "redraw": False},
                 "mode": "immediate",
                 "transition": {"duration": 30}}
            ],
            "label": str(x_var),
            "method": "animate",
        }
        for frame_name, x_var in zip(frame_names, x_vars)
    ],
}

# Each frame replaces every trace of the base figure with the trace at the
# same index in the corresponding plot.
fig.update(frames=[
    go.Frame(name=frame_name,
             data=plot.data,
             traces=list(range(len(fig.data))))
    for frame_name, plot in zip(frame_names, plots)])

fig.update_layout(sliders=[sliders_dict])
# yaxis1 / yaxis4 are the first y-axes of rows 1 and 2, matching the order of
# alpha_sigmas passed to make_plot above.
fig.update_layout(
    yaxis1=dict(title="VP Diffusion"),
    yaxis4=dict(title="Flow Matching"),
    height=600,)


fig.show()

fig.write_html('../assets/html/2025-04-28-distill-example/interactive_vp_vs_flow.html', include_plotlyjs="cdn")

0.0 1.5369426e-06
0.0 0.0012947327
0.0 1.5369426e-06
0.0 0.0012947327
0.0 0.018576652
0.0 0.017971894
0.0 0.06178116
0.0 0.04572126
0.0 0.48616976
0.0 0.29090664
0.0 0.9999999
0.0 0.5939408
