In [1]:
import argparse
import inspect
import os
import time

import matplotlib.pyplot as plt

import jax
from jax import jit, lax, random
from jax.example_libraries import stax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax.random import PRNGKey

import numpyro
from numpyro import optim
from numpyro.examples.datasets import MNIST, load_dataset
import adevjax
import genjax
from dataclasses import dataclass
from genjax import dippl
from genjax import gensp
from genjax import select, dirac

# Directory for run artifacts: a ".results" folder placed beside whatever file
# `inspect` reports as the definition site of a throwaway lambda.
# NOTE(review): inside a notebook kernel that file may not live next to the
# notebook itself — confirm where ".results" actually ends up.
RESULTS_DIR = os.path.join(
    os.path.abspath(os.path.dirname(inspect.getfile(lambda: None))),
    ".results",
)
# Create it eagerly; an already-existing directory is not an error.
os.makedirs(RESULTS_DIR, exist_ok=True)


def encoder(hidden_dim, z_dim):
    """Build the stax encoder network.

    The network is a shared `Dense(hidden_dim)` + `Softplus` trunk whose output
    is fanned out to two `Dense(z_dim)` heads. The second head is followed by
    `Exp`, so its output is strictly positive. All dense weights use
    `stax.randn()` initialisation.

    Args:
        hidden_dim: width of the shared hidden layer.
        z_dim: width of each of the two output heads.

    Returns:
        The `(init_fn, apply_fn)` pair produced by `stax.serial`; applying the
        network yields a pair of arrays, one per head.
    """
    # Shared trunk.
    trunk = (
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
    )
    # First head: unconstrained output.
    plain_head = stax.Dense(z_dim, W_init=stax.randn())
    # Second head: exponentiated so the result is always > 0.
    positive_head = stax.serial(
        stax.Dense(z_dim, W_init=stax.randn()),
        stax.Exp,
    )
    return stax.serial(
        *trunk,
        stax.FanOut(2),
        stax.parallel(plain_head, positive_head),
    )


def decoder(hidden_dim, out_dim):
    """Build the stax decoder network.

    Layers, in order: `Dense(hidden_dim)`, `Softplus`, `Dense(out_dim)`,
    `Sigmoid`. The final sigmoid keeps every output in the unit interval.
    All dense weights use `stax.randn()` initialisation.

    Args:
        hidden_dim: width of the hidden layer.
        out_dim: width of the output layer.

    Returns:
        The `(init_fn, apply_fn)` pair produced by `stax.serial`.
    """
    layers = [
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.Dense(out_dim, W_init=stax.randn()),
        stax.Sigmoid,
    ]
    return stax.serial(*layers)


# Define our gradient estimator using our loss language.
def svi_update(
    model,
    guide,
    optimizer,
):
    """Build a minibatch update step driven by a `dippl` loss.

    Args:
        model: object wrapped with `dippl.lower` and invoked on the merged
            guide/observation choices together with the decoder parameters.
        guide: object wrapped with `dippl.upper` and invoked with the encoder
            parameters and the observed image choice map.
        optimizer: object exposing `get_params(state)` and
            `update(grads, state)`.

    Returns:
        A function `batch_update(svi_state, batch)` where `svi_state` is a
        `(key, optimizer_state)` pair. It returns the updated
        `(key, optimizer_state)` pair and the per-example loss estimates
        (one entry per element of `batch`).
    """

    def _inner(key, encoder_params, decoder_params, data):
        # Per-example loss/gradient estimate; `data` is a single image that is
        # flattened to a length-784 vector and stored under address "image".
        v_chm = genjax.value_choice_map(
            genjax.choice_map({"image": data.reshape((28 * 28,))})
        )

        @dippl.loss
        def vae_loss(key, encoder_params, decoder_params):
            key, sub_key = jax.random.split(key)
            v = dippl.upper(guide)(sub_key, encoder_params, v_chm)
            merged = gensp.merge(v, v_chm)
            dippl.lower(model)(key, merged, decoder_params)

        loss, (
            encoder_params_grad,
            decoder_params_grad,
        ) = vae_loss.value_and_grad_estimate(key, (encoder_params, decoder_params))
        return (encoder_params_grad, decoder_params_grad), loss

    def batch_update(svi_state, batch):
        (key, optimizer_state) = svi_state
        encoder_params, decoder_params = optimizer.get_params(optimizer_state)
        # One fresh key per example; the carried key is advanced for next call.
        key, sub_key = jax.random.split(key)
        sub_keys = jax.random.split(sub_key, len(batch))
        # Map over keys and examples; parameters are shared across the batch.
        (encoder_grads, decoder_grads), loss = jax.vmap(
            _inner, in_axes=(0, None, None, 0)
        )(sub_keys, encoder_params, decoder_params, batch)
        # vmap stacks the per-example gradients along a new leading axis.
        # Average over that axis only: a bare `jnp.mean` would also average
        # across the parameter dimensions and collapse every leaf to a scalar.
        encoder_grads, decoder_grads = jtu.tree_map(
            lambda grad: jnp.mean(grad, axis=0), (encoder_grads, decoder_grads)
        )
        optimizer_state = optimizer.update(
            (encoder_grads, decoder_grads), optimizer_state
        )
        return (key, optimizer_state), loss

    return batch_update


@jit
def binarize(rng_key, batch):
    """Sample a 0/1 array from `batch`, treating each entry as a probability.

    Args:
        rng_key: JAX PRNG key used for the Bernoulli draw.
        batch: array of per-element success probabilities.

    Returns:
        An array with the same shape and dtype as `batch`, holding the draws.
    """
    draws = random.bernoulli(rng_key, p=batch)
    # bernoulli yields booleans; cast back to the input's dtype.
    return draws.astype(batch.dtype)

In [2]:
# Tunable settings for this run.
hidden_dim, z_dim = 10, 100
learning_rate = 1.0e-3
batch_size = 64

# Each builder returns an (init, apply) pair; the decoder emits one value per
# pixel of a 28 x 28 image.
encoder_nn_init, encoder_nn_apply = encoder(hidden_dim, z_dim)
decoder_nn_init, decoder_nn_apply = decoder(hidden_dim, 28 * 28)

# Model + guide close over the neural net apply functions.
@genjax.gen
def decoder_model(decoder_params):
    """Generative model: draw a latent, decode it, then draw the image.

    Traces two addresses, "latent" and "image", both via
    `dippl.mv_normal_diag_reparam`. Nothing is returned.
    """
    # Latent draw: zero first argument and all-ones second argument, length z_dim.
    prior_loc = jnp.zeros(z_dim)
    prior_scale = jnp.ones(z_dim)
    z = dippl.mv_normal_diag_reparam(prior_loc, prior_scale) @ "latent"
    # Push the latent through the decoder network closed over from above.
    decoded = decoder_nn_apply(decoder_params, z)
    # Image draw centred on the decoder output, with an all-ones second argument.
    pixel_scale = jnp.ones(784)
    observed = dippl.mv_normal_diag_reparam(decoded, pixel_scale) @ "image"


@genjax.gen
def encoder_model(encoder_params, chm):
    """Guide: map the observed image to a draw at address "latent".

    `chm` must expose `get_leaf_value()` returning something indexable by
    "image". Nothing is returned.
    """
    leaf = chm.get_leaf_value()
    observed = leaf["image"]
    # The encoder network yields the two arguments of the latent distribution.
    loc, scale = encoder_nn_apply(encoder_params, observed)
    z = dippl.mv_normal_diag_reparam(loc, scale) @ "latent"


# Wrap the two generative functions with `gensp.choice_map_distribution`,
# selecting the addresses each one exposes: the model exposes both "latent"
# and "image", the guide only "latent".
# NOTE(review): the meaning of the trailing `None` argument is not visible
# here — confirm against the gensp API.
model = gensp.choice_map_distribution(decoder_model, select("latent", "image"), None)
guide = gensp.choice_map_distribution(encoder_model, select("latent"), None)

# Optimizer and the batch update step built from model, guide and optimizer.
adam = optim.Adam(learning_rate)
svi_updater = svi_update(model, guide, adam)
# Seed used for parameter initialisation and for the sample batch below.
rng_key = PRNGKey(0)
train_init, train_fetch = load_dataset(MNIST, batch_size=batch_size, split="train")
# First call to train_init(); its results are used to fetch `sample_batch`.
num_train, train_idx = train_init()
rng_key, rng_key_binarize, rng_key_init = random.split(rng_key, 3)
# Separate initialisation keys for the two networks.
encoder_init_key, decoder_init_key = random.split(rng_key_init)
# The stax init functions return (output_shape, params); only params are kept.
_, encoder_params = encoder_nn_init(encoder_init_key, (784,))
_, decoder_params = decoder_nn_init(decoder_init_key, (z_dim,))
# Binarized first training batch.
# NOTE(review): `sample_batch` is not used anywhere else in the visible cells.
sample_batch = binarize(rng_key_binarize, train_fetch(0, train_idx)[0])
# Second call to train_init(); it overwrites num_train / train_idx.
# NOTE(review): presumably this re-draws the index order — confirm whether the
# repeat call is intentional.
num_train, train_idx = train_init()


@jit
def epoch_train(svi_state, rng_key, train_idx):
    """Run `svi_updater` once for every batch index in `range(num_train)`.

    Args:
        svi_state: state threaded through `svi_updater`.
        rng_key: base key; the batch index is folded in to get the key used to
            binarize each batch.
        train_idx: index argument forwarded to `train_fetch`.

    Returns:
        The state after the last batch. Per-batch losses are discarded.
    """

    def _step(batch_index, state):
        batch_key = random.fold_in(rng_key, batch_index)
        # train_fetch returns a tuple; only its first element is used here.
        images = train_fetch(batch_index, train_idx)[0]
        next_state, _ = svi_updater(state, binarize(batch_key, images))
        return next_state

    return lax.fori_loop(0, num_train, _step, svi_state)


# Train for one epoch.
# Derive two independent keys from the seed: one is carried inside the SVI
# state, the other drives batch binarization. Passing the same key to both
# would reuse a single PRNG key for two different purposes.
key = random.PRNGKey(314159)
state_key, binarize_key = random.split(key)
optimizer_state = adam.init((encoder_params, decoder_params))
svi_state = (state_key, optimizer_state)
svi_state = epoch_train(svi_state, binarize_key, train_idx)

Exception: <function invoke_closed_over at 0x7fdc488597e0>