In [1]:
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import time
import os
from dataclasses import dataclass
from collections import namedtuple
import pyro
import optax
from pyro.infer import SVI, TraceGraph_ELBO
import pyro.distributions as dist
import pyro.poutine as poutine
import pyro.contrib.examples.multi_mnist as multi_mnist
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.gridspec import GridSpec
import matplotlib.font_manager as font_manager
from matplotlib.patches import Rectangle
import matplotlib.ticker as ticker
import matplotlib.patches as patches
import seaborn as sns
from matplotlib import cm
from matplotlib.ticker import LinearLocator
from matplotlib import rcParams
from scipy.interpolate import griddata
import genjax
from genjax import grasp
from genjax import Pytree
import equinox as eqx
from genjax.typing import Any
from genjax.typing import Tuple
from genjax.typing import FloatArray
from genjax.typing import Int
from genjax.typing import IntArray
from genjax.typing import PRNGKey
from genjax.typing import typecheck
from genjax import choice_map

from numpyro.examples.datasets import MNIST
from numpyro.examples.datasets import load_dataset

import genjax
from genjax import grasp
from genjax import gensp
from genjax import select

import adevjax

console = genjax.pretty()
key = jax.random.PRNGKey(314159)  # global PRNG key, split further by later cells
sns.set_theme(style="white")

# Optional custom font for figures. This used to be a hard-coded path inside one
# user's home directory, which made `addfont` raise on any other machine.
# Override the location with the LATO_FONT_PATH environment variable; when the
# file is missing we keep the default (seaborn/matplotlib) font family instead.
font_path = os.environ.get(
    "LATO_FONT_PATH",
    "/home/femtomc/.local/share/fonts/Unknown Vendor/TrueType/Lato/Lato_Bold.ttf",
)
if os.path.isfile(font_path):
    font_manager.fontManager.addfont(font_path)
    custom_font_name = font_manager.FontProperties(fname=font_path).get_name()
    rcParams["font.family"] = custom_font_name
else:
    # Fall back to whatever family is currently configured, so later code that
    # reads `custom_font_name` still gets a usable family name.
    custom_font_name = rcParams["font.family"][0]
rcParams["figure.autolayout"] = True
label_fontsize = 70  # font size used for figure labels

# MNIST training split in batches of 1024. `data_batch` is element 0 of the
# first fetched batch (presumably the images; the labels are not used here).
train_init, train_fetch = load_dataset(MNIST, batch_size=1024, split="train")
num_train, train_idx = train_init()
data_batch = train_fetch(0)[0]

In [2]:
# Utilities for defining the model and the guide.
@dataclass
class Decoder(Pytree):
    """Two-layer MLP mapping a 10-d latent code to 28*28 Bernoulli logits.

    Both consumers of this network (the generative `model`, which passes the
    output to `genjax.tfp_bernoulli`, and the hand-written estimator, which uses
    `Bernoulli(logits=...)`) interpret the output as *logits*. The output layer
    therefore must not apply a sigmoid: treating sigmoid outputs as logits would
    squash every pixel probability into (0.5, 0.73).
    """

    dense_1: Any  # 10 -> 200 hidden layer
    dense_2: Any  # 200 -> 28*28 output layer

    def flatten(self):
        # Both layers are pytree children; there is no static auxiliary data.
        return (self.dense_1, self.dense_2), ()

    @classmethod
    def new(cls, key1, key2):
        """Build a decoder with freshly initialised layers.

        Args:
            key1: PRNG key for the 10 -> 200 hidden layer.
            key2: PRNG key for the 200 -> 28*28 output layer.
        """
        dense_1 = eqx.nn.Linear(10, 200, key=key1)
        dense_2 = eqx.nn.Linear(200, 28 * 28, key=key2)
        return cls(dense_1, dense_2)

    def __call__(self, z_what):
        """Return unnormalised Bernoulli logits (28*28 values) for one latent code."""
        hidden = jax.nn.leaky_relu(self.dense_1(z_what))
        return self.dense_2(hidden)


# Instantiate the decoder from two fresh sub-keys, advancing the global key.
key, *decoder_keys = jax.random.split(key, 3)
decoder = Decoder.new(*decoder_keys)


@dataclass
class Encoder(Pytree):
    """Amortised inference network for the guide.

    Maps a flattened 28*28 image to the parameters of a 10-d diagonal Gaussian:
    a location vector and a strictly positive (softplus) scale vector.
    """

    dense_1: Any  # 28*28 -> 200 hidden layer
    dense_2: Any  # 200 -> 20 output layer (10 locations, then 10 raw scales)

    def flatten(self):
        children = (self.dense_1, self.dense_2)
        return children, ()

    @classmethod
    def new(cls, key1, key2):
        """Build an encoder with freshly initialised layers (one PRNG key per layer)."""
        return Encoder(
            eqx.nn.Linear(28 * 28, 200, key=key1),
            eqx.nn.Linear(200, 20, key=key2),
        )

    def __call__(self, data):
        """Return `(loc, scale)`, each of length 10, for one flattened image."""
        out = self.dense_2(jax.nn.leaky_relu(self.dense_1(data)))
        loc, raw_scale = out[:10], out[10:]
        return loc, jax.nn.softplus(raw_scale)


# Instantiate the encoder from two fresh sub-keys, again advancing the global key.
key, *encoder_keys = jax.random.split(key, 3)
encoder = Encoder.new(*encoder_keys)

In [3]:
@genjax.gen
def model(decoder):
    """Generative model p(latent, image).

    Samples a 10-d standard-normal code at address "latent", maps it through
    `decoder`, and passes the result as the first argument of
    `genjax.tfp_bernoulli` for the pixels at address "image".
    """
    prior_loc, prior_scale = jnp.zeros(10), jnp.ones(10)
    z = genjax.tfp_mv_normal_diag(prior_loc, prior_scale) @ "latent"
    pixel_logits = decoder(z)
    _ = genjax.tfp_bernoulli(pixel_logits) @ "image"


@genjax.gen
def guide(encoder, chm):
    """Variational family q(latent | image), amortised by `encoder`.

    Reads the observed image from the choice map `chm` and samples address
    "latent" from a diagonal Gaussian parameterised by the encoder. The
    `_reparam` variant presumably provides reparameterisation (pathwise)
    gradients; confirm against the grasp documentation.
    """
    loc, scale = encoder(chm["image"])
    _ = grasp.mv_normal_diag_reparam(loc, scale) @ "latent"


def batch_elbo_grad_estimate(key, encoder, decoder, data_batch):
    """Per-example ELBO gradient estimates for a batch of images, via genjax/grasp.

    Args:
        key: PRNG key, split into one independent key per example.
        encoder: guide network (first argument of `guide`).
        decoder: model network (only argument of `model`).
        data_batch: images with the batch on the leading axis. Each example is
            flattened before being observed at address "image".

    Returns:
        The result of `grad_estimate` for every example, stacked along a new
        leading axis by `jax.vmap`. `encoder` and `decoder` are broadcast
        (shared across examples), not mapped.
    """

    def _per_example(example_key, encoder, decoder, image):
        observations = choice_map({"image": image.flatten()})
        elbo = grasp.elbo(model, guide, observations)
        model_args = (decoder,)
        guide_args = (encoder, observations)
        return elbo.grad_estimate(example_key, (model_args, guide_args))

    example_keys = jax.random.split(key, len(data_batch))
    batched = jax.vmap(_per_example, in_axes=(0, None, None, 0))
    return batched(example_keys, encoder, decoder, data_batch)


jitted = jax.jit(batch_elbo_grad_estimate)

# Warm-up: the first call traces and compiles. JAX dispatches work
# asynchronously, so block on every output leaf (all jit outputs are jax Arrays).
# Otherwise the warm-up execution can still be running when %timeit starts.
jtu.tree_map(lambda leaf: leaf.block_until_ready(), jitted(key, encoder, decoder, data_batch));

In [4]:
%timeit -n 5000 -r 10 jitted(key, encoder, decoder, data_batch)

1.64 ms ± 37 µs per loop (mean ± std. dev. of 10 runs, 5,000 loops each)


In [5]:
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

# Short aliases for the two TFP distributions used by the hand-written estimator.
MvNormalDiag, Bernoulli = tfd.MultivariateNormalDiag, tfd.Bernoulli

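# Baseline: the same per-example ELBO gradient estimator written by hand with TFP.
# Note: this cell intentionally reuses (shadows) the names `batch_elbo_grad_estimate` and `jitted`.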
def batch_elbo_grad_estimate(key, encoder, decoder, data_batch):
    """Hand-written per-example ELBO gradients (baseline for the genjax version).

    For each image x, draws one reparameterised sample z = loc + eps * scale with
    eps ~ N(0, I) and differentiates the single-sample ELBO estimate

        log p(x | z) + log p(z) - log q(z | x)

    with respect to `(encoder, decoder)`. This is the gradient of the ELBO itself
    (an objective to maximise), not of a loss.

    Args:
        key: PRNG key, split into one independent key per example.
        encoder: network mapping a flattened image to `(loc, scale)`.
        decoder: network mapping z to the values passed as `Bernoulli` logits.
        data_batch: images with the batch on the leading axis.

    Returns:
        A tuple `(encoder_grads, decoder_grads)` with the same structure as
        `(encoder, decoder)`, with a leading batch axis added by `jax.vmap`.
    """

    def single_estimate(key, encoder, decoder, data):
        image = data.flatten()
        # Used both as the noise source and as the prior p(z); it does not
        # depend on the parameters, so it can be built outside the grad closure.
        standard_normal = MvNormalDiag(jnp.zeros(10), jnp.ones(10))

        def elbo_estimate(params):
            enc, dec = params
            loc, scale = enc(image)
            eps = standard_normal.sample(seed=key)
            z = loc + eps * scale  # reparameterisation: gradients flow through loc/scale
            log_q = MvNormalDiag(loc, scale).log_prob(z)
            log_prior = standard_normal.log_prob(z)
            log_lik = Bernoulli(logits=dec(z)).log_prob(image).sum()
            return (log_lik + log_prior) - log_q

        return jax.grad(elbo_estimate)((encoder, decoder))

    example_keys = jax.random.split(key, len(data_batch))
    batched = jax.vmap(single_estimate, in_axes=(0, None, None, 0))
    return batched(example_keys, encoder, decoder, data_batch)


jitted = jax.jit(batch_elbo_grad_estimate)

# Warm-up: the first call traces and compiles the hand-written estimator. Block
# on every output leaf so the asynchronously dispatched execution finishes
# before the timing cell starts.
jtu.tree_map(lambda leaf: leaf.block_until_ready(), jitted(key, encoder, decoder, data_batch));

In [6]:
%timeit -n 5000 -r 10 jitted(key, encoder, decoder, data_batch)

1.6 ms ± 48.9 µs per loop (mean ± std. dev. of 10 runs, 5,000 loops each)
