In [1]:
import stanza.runtime
stanza.runtime.setup() # setup logging, etc.

from IPython.display import display

import stanza.util.ipython as ipyutil
from stanza.graphics import image_grid
from functools import partial

import jax
import jax.flatten_util
import matplotlib.pyplot as plt
import jax.numpy as jnp

from stanza.diffusion import DDPMSchedule
from stanza.datasets import image_datasets

# --- Configuration -----------------------------------------------------------
DATASET = "mnist"
# Number of diffusion timesteps for the DDPM schedule
# (presumably exposed afterwards as `schedule.num_steps` — TODO confirm).
NUM_DIFFUSION_STEPS = 100

# Load the dataset, its training split, and the "hypercube" normalizer that the
# cells below use to normalize images before diffusion.
dataset = image_datasets.create(DATASET)
train_data, normalizer = dataset.splits["train"], dataset.normalizers["hypercube"]

# Squared-cosine ("squaredcos_cap_v2") DDPM noise schedule.
schedule = DDPMSchedule.make_squaredcos_cap_v2(NUM_DIFFUSION_STEPS)

In [2]:
# Sanity check before training: sample with a denoiser computed directly from a
# small reference set of (normalized) training images. Set the flag to False to
# skip this cell.
RUN_GROUND_TRUTH_SAMPLING = True
NUM_REFERENCE_IMAGES = 32  # set to len(train_data) to use the whole training set
NUM_GT_SAMPLES = 10
GT_SAMPLING_SEED = 43

if RUN_GROUND_TRUTH_SAMPLING:
    normalized = jax.vmap(normalizer.normalize)(train_data.slice(0, NUM_REFERENCE_IMAGES))

    @jax.jit
    def sample_ground_truth(rng_key):
        """Draw one sample using the denoiser derived from `normalized`.

        Named distinctly from the later `sample` helpers so those cells do not
        silently shadow this one.

        Args:
            rng_key: PRNG key for the reverse-diffusion sampler.

        Returns:
            The output of `schedule.sample` for `normalizer.structure`.
        """
        def gt_denoiser(_, x, t):
            # Presumably the ideal denoiser for the empirical distribution of
            # `normalized` — TODO confirm against DDPMSchedule.compute_denoised.
            denoised = schedule.compute_denoised(x, t, normalized)
            return schedule.output_from_denoised(x, t, denoised)
        return schedule.sample(rng_key, gt_denoiser, normalizer.structure)

    gt_rng_keys = jax.random.split(jax.random.PRNGKey(GT_SAMPLING_SEED), NUM_GT_SAMPLES)
    samples = jax.vmap(sample_ground_truth)(gt_rng_keys)
    display(ipyutil.as_image(image_grid(samples)))

HBox(children=(HTML(value='<style>\n.cell-output-ipywidget-background {\n    background-color: transparent !im…

In [3]:
from stanza.nn.unet import DiffusionUNet
from stanza import train as st
import stanza.train.ipython
import optax

model = DiffusionUNet(time_embed_dim=32)

# Initialize the UNet's variables by running a jitted `model.init` on one
# training example at timestep 0.
_init_example = dataset.splits["train"][0]
_init_rng = jax.random.PRNGKey(42)
init_params = jax.jit(model.init)(_init_rng, _init_example, timestep=0)

def loss_fn(params, _iteration, rng_key, sample):
    """Per-example diffusion training loss.

    Args:
        params: UNet variables passed to `model.apply`.
        _iteration: training iteration (unused).
        rng_key: PRNG key forwarded to `schedule.loss`.
        sample: one raw example from the training split.

    Returns:
        `st.LossOutput` with the scalar loss, also reported as the "loss" metric.
    """
    # The samplers operate in `normalizer` space (they sample
    # `normalizer.structure`, and the ground-truth check normalizes the train
    # data first), so train in that same space.
    sample = normalizer.normalize(sample)
    denoiser = lambda _, x, t: model.apply(params, x, timestep=t)
    loss = schedule.loss(rng_key, denoiser, sample)
    return st.LossOutput(
        loss=loss,
        metrics={"loss": loss}
    )

@jax.jit
def sample(vars, rng_key):
    """Draw one sample from the UNet with variables `vars` via `schedule.sample`."""
    def learned_denoiser(_, x, t):
        return model.apply(vars, x, timestep=t)
    return schedule.sample(rng_key, learned_denoiser, normalizer.structure)

# Batched sampler: `vars` is shared (in_axes=None); the PRNG keys are mapped over axis 0.
samples_batch = jax.vmap(sample, in_axes=(None, 0))

def compute_metrics(params, _iteration, rng_key, sample):
    """Placeholder per-example metrics function; currently reports no metrics.

    It is not referenced by any visible cell.

    Args:
        params: model variables (unused for now).
        _iteration: training iteration (unused).
        rng_key: PRNG key (unused for now).
        sample: a single training example (unused for now).

    Returns:
        An empty metrics dict.

    TODO: implement. The earlier draft did the following:
      - noised `sample` via `schedule.add_noise`, but threw the result away;
      - never used its `denoiser`;
      - called `jax.random.uniform(t_rng, (), 1, ...)`. The third positional
        argument of `uniform` is `dtype`, so that call was invalid.
    When implementing:
      - pass `minval=`/`maxval=` by keyword, or use `jax.random.randint` for an
        integer timestep;
      - keep the noised value;
      - normalize `sample` the same way training does.
    """
    return {}

@partial(jax.jit, static_argnames=("num_samples",))
def generate_samples(params, rng_key, num_samples=64):
    """Draw a batch of samples from the model for logging/display.

    Args:
        params: model variables, forwarded to `samples_batch`.
        rng_key: PRNG key, split into one key per sample.
        num_samples: number of samples to draw. It is static under jit because
            it fixes the number of keys, and hence the output batch size.

    Returns:
        The batch produced by `samples_batch`, one sample per key. These are in
        the sampler's (normalized) space.

    NOTE(review): confirm that `st.log_to` calls this as (params, rng_key) and
    that `st.ipython.display_logger` can render a raw sample batch.
    """
    rng_keys = jax.random.split(rng_key, num_samples)
    return samples_batch(params, rng_keys)

def generate_(rng, train_state):
    """Unfinished stub: ignores both arguments and returns None.

    It is not referenced by any visible cell.
    """
    return None

# Training configuration.
epochs = 100
batch_size = 64
# Total optimizer steps for the run (epochs * examples / batch_size, floored);
# used as the decay horizon of the cosine learning-rate schedule below.
iterations = epochs*len(dataset.splits["train"])//batch_size
# Train the UNet with Adam. The learning rate starts at 3e-4 and follows
# optax's cosine decay over `iterations` steps.
# NOTE(review): st.fit's return type is not visible here. Confirm whether
# `trained_params` holds raw model variables (as `model.apply` expects) or a
# training-state object.
trained_params = st.fit(
    data=dataset.splits["train"],
    batch_loss_fn=st.batch_loss(loss_fn),
    init_vars=init_params,
    rng_key=jax.random.PRNGKey(42),
    optimizer=optax.adam(optax.cosine_decay_schedule(3e-4, iterations)),
    max_epochs=epochs,
    batch_size=batch_size,
    hooks=[
        # Console-log training metrics every 500 iterations.
        st.every_n_iterations(500, st.console_logger(prefix="train.", metrics=True)),
        # Presumably once per epoch: run validation, then log generated samples.
        st.every_epoch(st.validate(
                # NOTE(review): validation uses the *training* split, so it does
                # not measure generalization. Switch to a held-out split if the
                # dataset provides one.
                data=dataset.splits["train"],
                batch_size=batch_size,
                batch_loss_fn=st.batch_loss(loss_fn),
                log_hooks=[st.console_logger(prefix="validation.")]
            ),
            # NOTE(review): generate_samples is defined as (params, rng_key).
            # Confirm that is the argument order st.log_to uses.
            st.log_to(
                generate_samples,
                st.ipython.display_logger()
            )
        )
    ]
)

HBox(children=(HTML(value='<style>\n.cell-output-ipywidget-background {\n    background-color: transparent !im…

HBox(children=(HTML(value='<style>\n.cell-output-ipywidget-background {\n    background-color: transparent !im…

In [None]:
def sample(rng_key):
    """Draw one sample from the trained model.

    Note that this rebinds the name `sample` defined in the training cell.
    `samples_batch` was built from that earlier function and is unaffected.

    Args:
        rng_key: PRNG key for the reverse-diffusion sampler.

    Returns:
        The output of `schedule.sample` for `normalizer.structure`. This is
        presumably one image in normalized space — TODO confirm.

    NOTE(review): assumes `trained_params` holds model variables that
    `model.apply` accepts. Confirm what `st.fit` returns.
    """
    # Named `denoiser`, not `model`. Rebinding `model` locally would make the
    # lambda refer to itself instead of the global UNet.
    denoiser = lambda _, x, t: model.apply(trained_params, x, timestep=t)
    return schedule.sample(rng_key, denoiser, structure=normalizer.structure)