In [1]:
import jax.numpy as jnp
import jax
import plotly.graph_objects as go

In [11]:
from stanza.diffusion import DDPMSchedule

def make_plot(dim=1024, bkw=True, num_trajectories=256, seed=42):
    """Scatter-plot diffusion trajectories for a two-point toy dataset.

    The dataset holds 100 samples, each either all ``-1`` or all ``+1``
    across ``dim`` coordinates. Trajectories use a linear 100-step DDPM
    schedule. They are either backward (sampled with the ground-truth
    denoiser computed from the dataset) or forward (a randomly chosen
    training point being noised). At every point of every trajectory the
    ground-truth denoised estimate is also computed. Both the samples and
    their denoised estimates are projected onto the first coordinate axis.

    Args:
        dim: Dimensionality of the data points.
        bkw: If True, plot sampling (backward) trajectories. Otherwise plot
            forward-noising trajectories.
        num_trajectories: Number of trajectories to draw.
        seed: PRNG seed used to draw the trajectories. The dataset itself
            is always generated with seed 0.

    Returns:
        A ``plotly.graph_objects.Figure`` with one scatter trace for the
        trajectory points and one for their ground-truth denoised estimates,
        both plotted against the diffusion step.
    """
    # Two antipodal modes; the data is deliberately left unnormalized.
    a = -jnp.ones((dim,))
    b = jnp.ones((dim,))

    def generate(rng_key: jax.Array) -> jax.Array:
        # Pick one of the two modes uniformly at random.
        t = jax.random.choice(rng_key, 2)
        return jax.lax.cond(t == 0, lambda: a, lambda: b)
    train_data = jax.vmap(generate)(jax.random.split(jax.random.PRNGKey(0), 100))
    schedule = DDPMSchedule.make_linear(100)

    # Given a (small) dataset, compute_denoised yields the ground-truth
    # model output. This helps check whether a learned denoiser has enough
    # capacity to capture the true score function.
    def gt_denoiser(_rng_key, x, t):
        denoised = schedule.compute_denoised(x, t, train_data)
        return schedule.output_from_denoised(x, t, denoised)

    def sample_bkw_traj(rng_key):
        # Step labels run from num_steps down to 0 for the sampling trajectory.
        ts = schedule.num_steps - jnp.arange(schedule.num_steps + 1)
        # NOTE(review): train_data[0] is presumably used only as a shape/dtype
        # template for the initial sample; confirm against DDPMSchedule.sample.
        return ts, schedule.sample(rng_key, gt_denoiser, train_data[0], trajectory=True)[1]

    def sample_fwd_traj(rng_key):
        r_s, r_t = jax.random.split(rng_key)
        i = jax.random.choice(r_s, len(train_data))
        ts = jnp.arange(schedule.num_steps + 1)
        return ts, schedule.forward_trajectory(r_t, train_data[i])[0]

    sample_traj = sample_bkw_traj if bkw else sample_fwd_traj

    # Unit vector along the first coordinate axis.
    proj_dir = jnp.zeros((dim,)).at[0].set(1)
    proj_dir = proj_dir / jnp.linalg.norm(proj_dir)

    ts, trajectories = jax.vmap(sample_traj)(
        jax.random.split(jax.random.PRNGKey(seed), num_trajectories))
    # Ground-truth denoised estimate for every (trajectory, step) pair.
    denoise_steps = jax.vmap(schedule.compute_denoised, in_axes=(0, 0, None))
    denoised_traj = jax.vmap(denoise_steps, in_axes=(0, 0, None))(trajectories, ts, train_data)

    # Flatten (trajectory, step) into a single point axis for scattering.
    ts = ts.reshape(-1)
    xs = trajectories.reshape((-1, dim))
    zs = denoised_traj.reshape((-1, dim))
    x_projs = xs @ proj_dir
    z_projs = zs @ proj_dir

    direction = "Backward (sampling)" if bkw else "Forward (noising)"
    fig = go.Figure([
        go.Scatter(x=ts, y=x_projs, mode='markers', opacity=0.5, name='trajectory x_t'),
        go.Scatter(x=ts, y=z_projs, mode='markers', opacity=0.5, name='ground-truth denoised'),
    ])
    fig.update_layout(
        title=f"{direction} trajectories, dim={dim}",
        xaxis_title="diffusion step t",
        yaxis_title="projection onto first coordinate",
    )
    return fig

In [12]:
make_plot(dim=256, bkw=True)

In [13]:
make_plot(dim=1, bkw=True)