In [1]:
import stanza.runtime
stanza.runtime.setup() # setup logging, etc.

import jax
import jax.flatten_util
import plotly.graph_objects as go
import plotly.subplots
import jax.numpy as jnp
import matplotlib.pyplot as plt

from stanza.data import PyTreeData
from stanza.data.normalizer import StdNormalizer
from stanza.diffusion import DDPMSchedule
from stanza.random import PRNGSequence
from stanza import partial

In [101]:
# Generate some data
def generate(rng_key: jax.Array) -> jax.Array:
    """Draw a single 2-D point lying on a spiral.

    A scalar ``t`` is sampled with ``jax.random.uniform`` (default range
    ``[0, 1)``); the point is placed at angle ``10 * t`` radians and radius
    ``0.5 * t`` from the origin.

    Args:
        rng_key: JAX PRNG key consumed by the uniform draw.

    Returns:
        Array of shape ``(2,)`` holding the ``(x, y)`` coordinates.
    """
    t = jax.random.uniform(rng_key, ())
    angle = 10*t
    radius = 0.5*t
    # Unit direction first, then scale: element-wise the same products as
    # cos(angle)*radius and sin(angle)*radius.
    direction = jnp.stack((jnp.cos(angle), jnp.sin(angle)))
    return direction*radius

# Train / test point sets: 64 spiral points each, drawn from independent seeds
# (42 for train, 43 for test) and wrapped as PyTreeData.
batched_generate = jax.vmap(generate)
train_keys = jax.random.split(jax.random.PRNGKey(42), 64)
test_keys = jax.random.split(jax.random.PRNGKey(43), 64)
train_data = PyTreeData(batched_generate(train_keys))
test_data = PyTreeData(batched_generate(test_keys))

# The normalizer is fitted on the training split only; a normalized copy of the
# raw training array is kept alongside it for the samplers defined further down.
normalizer = StdNormalizer.from_data(train_data, component_wise=False)
train_data_normalized = jax.vmap(normalizer.normalize)(train_data.tree)

def plot_samples(samples, data_samples=None):
    """Scatter one or two sets of 2-D points on a fixed 800x800 canvas.

    Args:
        samples: Point set indexed as ``samples[:, 0]`` / ``samples[:, 1]``;
            drawn as blue markers.
        data_samples: Optional second point set with the same layout. When
            given, it is added first, as red markers whose hover text is the
            row index of each point.

    Returns:
        A ``go.Figure`` with both axes fixed to ``[-0.5, 0.5]`` and all
        margins set to zero.
    """
    def _marker_trace(points, color, **extra):
        # Shared styling for both point sets; only colour and hover text vary.
        return go.Scatter(
            x=points[:, 0], y=points[:, 1], mode='markers',
            marker=dict(color=color, size=4), opacity=0.5,
            showlegend=False, **extra
        )

    fig = go.Figure()
    if data_samples is not None:
        row_labels = [f"{i}" for i in range(len(data_samples))]
        fig.add_trace(_marker_trace(data_samples, 'red', text=row_labels))
    fig.add_trace(_marker_trace(samples, 'blue'))

    fig.update_layout(
        width=800, height=800,
        xaxis=dict(range=[-0.5, 0.5]),
        yaxis=dict(range=[-0.5, 0.5]),
        margin=dict(l=0, r=0, b=0, t=0, pad=0),
    )
    return fig

# NOTE(review): a notebook cell only renders its final expression, so if this
# call shares a cell with the code that follows, the figure built here is
# discarded — confirm whether a cell boundary was lost at this point.
plot_samples(train_data.tree)
@partial(jax.jit, static_argnames=("schedule_type", "prediction_type", "T", "trajectory"))
def sample(schedule_type="cos", prediction_type="epsilon", T=16, time_offset=0, offset=None, rng=None, trajectory=False):
    """Draw 1024 samples by running a DDPM schedule with a data-derived denoiser.

    The denoiser is obtained from ``train_data_normalized`` through
    ``schedule.compute_denoised``; no learned model is involved.

    Args:
        schedule_type: ``"cos"``, ``"linear"`` or ``"scaled_linear"`` (static).
        prediction_type: Forwarded to the ``DDPMSchedule`` constructors (static).
        T: Step count passed as the first argument of the schedule
            constructor (static).
        time_offset: Lower bound applied to the timestep the denoiser sees.
        offset: Forwarded to ``DDPMSchedule.make_squaredcos_cap_v2``; not used
            by the ``"linear"`` schedule.
        rng: PRNG key, split into 1024 per-sample keys. Required.
        trajectory: When true, also return each sample's trajectory (static).

    Returns:
        The unnormalized samples, stacked over the 1024 keys. With
        ``trajectory=True`` a ``(samples, trajectories)`` tuple is returned,
        both passed through ``normalizer.unnormalize``.

    Raises:
        ValueError: If ``rng`` is ``None`` or ``schedule_type`` is not one of
            the supported names. Both checks run while the function is traced.
    """
    # Fail early with a readable message instead of deep inside jax.random.split.
    if rng is None:
        raise ValueError("sample() requires a PRNG key: pass it via the `rng` argument")

    if schedule_type == "cos":
        schedule = DDPMSchedule.make_squaredcos_cap_v2(T, offset=offset, prediction_type=prediction_type)
    elif schedule_type == "linear":
        schedule = DDPMSchedule.make_linear(T, prediction_type=prediction_type)
    elif schedule_type == "scaled_linear":
        # NOTE(review): despite the name, this branch starts from the
        # squared-cosine schedule and then rescales it.
        schedule = DDPMSchedule.make_squaredcos_cap_v2(T, offset=offset, prediction_type=prediction_type)
        # rescale so that we are discretizing the same SDE
        schedule = DDPMSchedule.make_rescaled(T, schedule, prediction_type=prediction_type)
    else:
        # Without this branch `schedule` stays unbound and the failure only
        # shows up later as a NameError raised from inside the closures below.
        raise ValueError(
            f"unknown schedule_type {schedule_type!r}; "
            "expected 'cos', 'linear' or 'scaled_linear'"
        )

    def gt_denoiser(_rng_key, x, t):
        # Never query the denoiser below `time_offset`.
        t = jnp.maximum(t, time_offset)
        denoised = schedule.compute_denoised(x, t, train_data_normalized)
        output = schedule.output_from_denoised(x, t, denoised)
        return output

    def sample_gt(rng_key):
        # NOTE(review): `train_data[0]` is the raw (unnormalized) first item;
        # presumably it only serves as a structure/shape template — confirm
        # against DDPMSchedule.sample.
        if trajectory:
            sample, traj = schedule.sample(rng_key, gt_denoiser, train_data[0], trajectory=trajectory)
            return normalizer.unnormalize(sample), jax.vmap(normalizer.unnormalize)(traj)
        else:
            sample = schedule.sample(rng_key, gt_denoiser, train_data[0], trajectory=trajectory)
            return normalizer.unnormalize(sample)

    return jax.vmap(sample_gt)(jax.random.split(rng, 1024))

# Cell output: scatter of the raw (unnormalized) training points.
plot_samples(train_data.tree)

In [102]:
types = ["cos", "linear"]
Ts = [8, 16, 64, 128, 256]
rng = PRNGSequence(42)

# One batch of samples per (schedule, step count). The step count is the outer
# loop, so keys are pulled from `rng` in T-major order.
samples = {}
for num_steps in Ts:
    for kind in types:
        samples[(kind, num_steps)] = sample(schedule_type=kind, T=num_steps, rng=next(rng))

# 1-based subplot coordinates: one row per schedule, one column per step count.
row_of = {kind: row for row, kind in enumerate(types, start=1)}
col_of = {num_steps: col for col, num_steps in enumerate(Ts, start=1)}

fig = plotly.subplots.make_subplots(
    rows=len(types), cols=len(Ts),
    column_titles=[f"T={num_steps}" for num_steps in Ts],
    row_titles=[f"{kind}" for kind in types],
    vertical_spacing=0.1, horizontal_spacing=0.05
)
for (kind, num_steps), generated in samples.items():
    panel = plot_samples(train_data.tree, generated)
    fig.add_traces(panel.data, rows=row_of[kind], cols=col_of[num_steps])
fig.update_layout(margin=dict(l=40,r=40,t=40,b=40))
fig


In [103]:
rng = PRNGSequence(42)
# Only the trajectories are used in this cell. The discarded final samples get
# a real name instead of `_`, which IPython reserves for the last cell output.
_final_samples, trajectory = sample(
    schedule_type="cos", T=32,
    offset=0.01, prediction_type="sample",
    time_offset=5, rng=next(rng),
    trajectory=True
)

# Positions along axis 1 of `trajectory` to display, one subplot column each.
# NOTE(review): axis 0 is the per-key batch; axis 1 is presumably the sampling
# step — confirm the ordering against DDPMSchedule.sample.
ts = [0, 1, 8, 16, 32]
fig = plotly.subplots.make_subplots(rows=1, cols=len(ts),
        column_titles=[f"t={t}" for t in ts],
        vertical_spacing=0.1, horizontal_spacing=0.05
    )

for col, t in enumerate(ts, start=1):
    # Named `snapshot` so the global `samples` dict built by the previous cell
    # is not silently rebound to an array.
    snapshot = trajectory[:,t]
    fig.add_traces(
        plot_samples(train_data.tree, snapshot).data, rows=1, cols=col
    )
fig.update_layout(margin=dict(l=40,r=40,t=40,b=40))
fig.show()

# Overlay two individual trajectories on top of the training points.
fig = go.Figure(layout=dict(width=800, height=800))
fig.update_layout(margin=dict(l=40,r=40,t=40,b=40))
fig.add_traces(plot_samples(train_data.tree).data)

# Position along the trajectory rescaled to [0, 1]; used as the marker colour.
num_points = trajectory.shape[1]
T = jnp.arange(num_points)/(num_points - 1)
from plotly.colors import sequential

# (row of `trajectory`, colour reached after the first 20% of the colourscale)
highlighted = [(872, "red"), (925, "green")]
for row, hue in highlighted:
    path = trajectory[row]
    fade = [(0,"black"), (0.2,hue), (1., hue)]
    fig.add_trace(
        go.Scatter(
            x=path[:,0], y=path[:,1], showlegend=False, mode='markers',
            marker=dict(color=T, size=5, colorscale=fade),
        )
    )
fig.show()


In [104]:
Ts = [8, 16, 32, 64, 128]
offsets = [5e-2, 1e-2, 1e-3, 5e-4]
rng = PRNGSequence(42)

# One batch of samples per (step count, offset). The offset is the outer loop,
# so keys are pulled from `rng` in offset-major order.
samples = {}
for off in offsets:
    for num_steps in Ts:
        samples[(num_steps, off)] = sample(
            schedule_type="scaled_linear", T=num_steps, offset=off, rng=next(rng)
        )

# 1-based subplot coordinates: one row per offset, one column per step count.
row_of = {off: row for row, off in enumerate(offsets, start=1)}
col_of = {num_steps: col for col, num_steps in enumerate(Ts, start=1)}

fig = plotly.subplots.make_subplots(
    rows=len(offsets), cols=len(Ts),
    column_titles=[f"T={num_steps}" for num_steps in Ts],
    row_titles=[f"offset={off}" for off in offsets],
    vertical_spacing=0.1, horizontal_spacing=0.05
)
for (num_steps, off), generated in samples.items():
    panel = plot_samples(train_data.tree, generated)
    fig.add_traces(panel.data, rows=row_of[off], cols=col_of[num_steps])
fig.update_layout(margin=dict(l=40,r=40,t=40,b=40), width=800, height=800)
fig

In [99]:
# Module-level squared-cosine schedule built with 64 as its first argument and
# prediction_type="sample"; the functions defined below read it as a global.
schedule = DDPMSchedule.make_squaredcos_cap_v2(64, prediction_type="sample")
# Display the schedule's own visualisation figure.
schedule.visualize().show()
def gt_denoiser(_rng_key, x, t):
    """Denoiser derived from the normalized training set.

    Args:
        _rng_key: Unused; present to match the denoiser call signature.
        x: Current noisy value.
        t: Timestep at which ``x`` is evaluated.

    Returns:
        The result of ``schedule.output_from_denoised`` applied to the
        denoised estimate computed from ``train_data_normalized``.
    """
    return schedule.output_from_denoised(
        x, t, schedule.compute_denoised(x, t, train_data_normalized)
    )
def sample_gt(rng_key):
    """Run the global ``schedule`` once with ``gt_denoiser`` and unnormalize the result.

    Args:
        rng_key: PRNG key handed to ``schedule.sample``.

    Returns:
        The sample after ``normalizer.unnormalize``.
    """
    normalized_sample = schedule.sample(rng_key, gt_denoiser, train_data[0])
    return normalizer.unnormalize(normalized_sample)
# 1024 independent draws (seed 42), plotted together with the training points.
gt_keys = jax.random.split(jax.random.PRNGKey(42), 1024)
gt_samples = jax.vmap(sample_gt)(gt_keys)
plot_samples(train_data.tree, gt_samples)