In [1]:
# Set the inline backend's figure format to SVG for figures rendered in this notebook.
%config InlineBackend.figure_format = 'svg'

In [2]:
import itertools
from dataclasses import dataclass

import genjax
from genjax import dippl
from genjax import gensp
from genjax import select, dirac
import equinox as eqx
import optax
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import config
import adevjax
import numpy as np  # used by `dataloader`; previously only reachable via the star import below
from datasets import *
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

# Global JAX setting: turn on the `jax_debug_nans` flag (NaN debugging).
config.update("jax_debug_nans", True)
# NOTE(review): presumably installs genjax's pretty-printing console; `console`
# is not referenced again in the visible cells — confirm it is still needed.
console = genjax.pretty(show_locals=False)
# Root PRNG key for the notebook (the training cell re-creates it with the same seed).
key = jax.random.PRNGKey(314159)

# Plotting: use seaborn's "white" theme for every figure below.
sns.set_theme(style="white")

# Data.
# `mnist` is provided by the star import of `datasets`; only the first of its four
# return values is kept (presumably the training images — confirm against the loader).
train_images, _, _, _ = mnist()
# Binarize: entries above 0.5 become 1.0, all others become 0.0.
train_images = jnp.where(train_images > 0.5, 1.0, 0.0)


def dataloader(image_array, batch_size):
    """Yield shuffled, fixed-size mini-batches from `image_array` forever.

    Each pass over the data draws a fresh permutation of the row indices with
    `np.random.permutation` and yields consecutive slices of `batch_size` rows.
    A trailing remainder shorter than `batch_size` is dropped, so every batch
    has the same leading dimension.

    Args:
        image_array: Indexable array whose first axis enumerates the examples.
        batch_size: Number of examples per batch; must satisfy
            `0 < batch_size <= len(image_array)`.

    Yields:
        `image_array[idx]` for an index array `idx` of length `batch_size`.

    Raises:
        ValueError: If `batch_size` is not in `(0, len(image_array)]`. The check
            runs when the first batch is requested, because this is a generator.
    """
    # Size the epoch from the argument itself, not from a notebook global, so the
    # loader works for any array passed in.
    dataset_size = len(image_array)
    if not 0 < batch_size <= dataset_size:
        # Outside this range the loop would either spin forever without yielding
        # or emit empty batches indefinitely.
        raise ValueError(
            f"batch_size must be in (0, {dataset_size}], got {batch_size}"
        )
    indices = np.arange(dataset_size)
    while True:
        perm = np.random.permutation(indices)
        # Only start offsets that leave room for a full batch.
        for start in range(0, dataset_size - batch_size + 1, batch_size):
            yield image_array[perm[start : start + batch_size]]

ImportError: cannot import name 'ChoiceMapDistribution' from 'genjax._src.gensp' (/home/femtomc/Research/genjax/src/genjax/_src/gensp/__init__.py)

## Visualize some of the data examples

In [None]:
# Show a 3x3 grid of training digits, taking every 10th image (indices 0, 10, ..., 80).
fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(8, 8))
# `axs.flat` walks the grid row by row, matching the original (i, j) product order.
for panel, sub_axis in enumerate(axs.flat):
    sub_axis.set_axis_off()
    # `train_images` was binarized to {0, 1} when it was loaded, so no intensity
    # rescaling is applied before display.
    sub_axis.imshow(train_images[10 * panel].reshape(28, 28), cmap="gray")

## Gradients using `DIPPL`

In [None]:
@genjax.gen
def decoder_model(decoder):
    """Generative model: sample a latent code, decode it, then sample an image.

    Traces two addresses: "latent" (10-dimensional, zero location and unit
    diagonal) and "image" (784-dimensional, located at `decoder(latent)` with
    unit diagonal). The function has no explicit return value.
    """
    z = dippl.mv_normal_diag_reparam(jnp.zeros(10), jnp.ones(10)) @ "latent"
    decoded = decoder(z)
    # The sampled image is only recorded at its address; nothing is bound to it.
    dippl.mv_normal_diag_reparam(decoded, jnp.ones(784)) @ "image"


@genjax.gen
def encoder_model(encoder, chm):
    """Variational family: sample the latent code given an observed image.

    Args:
        encoder: Callable mapping the image to a pair `(μ, Σ_diag)`.
        chm: Value choice map whose leaf value holds the observation under the
            key "image".

    Traces a single address, "latent", drawn from a diagonal Gaussian
    parameterized by the encoder's outputs.
    """
    image = chm.get_leaf_value()["image"]
    μ, Σ_diag = encoder(image)
    # Use the scale predicted by the encoder. It was previously computed and then
    # replaced by `jnp.ones_like(μ)`, which pinned the posterior to unit scale and
    # left half of the encoder's output head without any effect.
    dippl.mv_normal_diag_reparam(μ, Σ_diag) @ "latent"


# Rebind both generative functions to the result of `gensp.choice_map_distribution`,
# restricted to the selected addresses: "latent" and "image" for the decoder model,
# "latent" only for the encoder model.
# NOTE(review): the meaning of the trailing `None` argument is not visible in this
# notebook — confirm against `gensp.choice_map_distribution`.
decoder_model = gensp.choice_map_distribution(
    decoder_model, select("latent", "image"), None
)
encoder_model = gensp.choice_map_distribution(encoder_model, select("latent"), None)

# Single-example gradient estimator, expressed in the `dippl` loss language.
def variational_value_and_grad(key, data, encoder, decoder):
    """Estimate the variational objective and its gradients for one example.

    Args:
        key: PRNG key forwarded to the estimator.
        data: One observation, stored under the "image" address.
        encoder: Encoder network passed to `encoder_model`.
        decoder: Decoder network passed to `decoder_model`.

    Returns:
        The result of `vae_loss.value_and_grad_estimate`, which callers unpack
        as `(loss, (encoder_grad, decoder_grad))`.
    """
    observed = genjax.value_choice_map(genjax.choice_map({"image": data}))

    @dippl.loss
    def vae_loss(encoder, decoder):
        # Sample from the variational family, then evaluate the model on the
        # sample merged with the observation.
        proposed = dippl.upper(encoder_model)(encoder, observed)
        dippl.lower(decoder_model)(gensp.merge(proposed, observed), decoder)

    networks = (encoder, decoder)
    return vae_loss.value_and_grad_estimate(key, networks)


def minibatch_value_and_grad(key, data, encoder, decoder):
    """Average the per-example loss and gradient estimates over a mini-batch.

    `variational_value_and_grad` is vmapped over the leading axis of `data`
    with one PRNG key per example; the networks are shared across the batch.

    Returns:
        `(loss, (encoder_grad, decoder_grad))`, where `loss` is the batch mean
        and every gradient leaf is averaged over the batch axis.
    """
    per_example = jax.vmap(variational_value_and_grad, in_axes=(0, 0, None, None))
    example_keys = jax.random.split(key, len(data))
    losses, (enc_grads, dec_grads) = per_example(example_keys, data, encoder, decoder)

    def batch_mean(leaf):
        return jnp.mean(leaf, axis=0)

    encoder_grad, decoder_grad = jtu.tree_map(batch_mean, (enc_grads, dec_grads))
    return jnp.mean(losses), (encoder_grad, decoder_grad)

## Encoder/decoder architectures

In [None]:
@dataclass
class EncoderNetwork(genjax.Pytree):
    """Convolutional encoder from a 784-element image to Gaussian parameters.

    Calling an instance returns `(mu, scale)`: two vectors of length
    `latent_dim` taken from the two halves of the final dense layer's output,
    with the second half passed through `exp(0.5 * ·)`.
    """

    # Size of the latent code; kept as static (auxiliary) data in `flatten`.
    latent_dim: genjax.typing.Int
    # [conv_1, conv_2, dense] — the layers applied in order by `__call__`.
    layers: genjax.typing.List

    def flatten(self):
        # First tuple: the layers; second tuple: `latent_dim`.
        # NOTE(review): presumably (dynamic children, static aux) per the
        # `genjax.Pytree` contract — confirm.
        return (self.layers,), (self.latent_dim,)

    @staticmethod
    def new(key, latent_dim):
        """Build an encoder with freshly initialized layers.

        Declared as a static method so that it behaves the same whether it is
        called on the class or on an instance.

        Args:
            key: PRNG key, split so each layer gets its own key.
            latent_dim: Size of the latent code; the dense head emits
                `2 * latent_dim` values.

        Returns:
            A new `EncoderNetwork`.
        """
        key, sub_key = jax.random.split(key)
        conv_1 = eqx.nn.Conv2d(
            in_channels=1, out_channels=32, kernel_size=3, stride=(2, 2), key=sub_key
        )
        key, sub_key = jax.random.split(key)
        conv_2 = eqx.nn.Conv2d(
            in_channels=32, out_channels=64, kernel_size=3, stride=(2, 2), key=sub_key
        )
        # 2304 = 64 * 6 * 6, assuming unpadded convolutions (28 -> 13 -> 6) — TODO
        # confirm against the equinox defaults if the architecture changes.
        dense = eqx.nn.Linear(
            in_features=2304, out_features=latent_dim + latent_dim, key=key
        )
        layers = [conv_1, conv_2, dense]
        return EncoderNetwork(latent_dim, layers)

    def __call__(self, v):
        """Encode one image into `(mu, scale)` for the latent distribution."""
        # Single-channel 28x28 layout expected by the convolutions.
        v = v.reshape(1, 28, 28)
        for layer in self.layers[:-1]:
            v = jax.nn.relu(layer(v))
        v = v.flatten()
        v = self.layers[-1](v)  # Dense
        mu = v[0 : self.latent_dim]
        # Named `sigma`, but it is exponentiated with a factor of 0.5 below, so it
        # presumably acts as a log-variance.
        sigma = v[self.latent_dim :]
        return mu, jnp.exp(0.5 * sigma)


@dataclass
class DecoderNetwork(genjax.Pytree):
    """Transposed-convolution decoder from a latent vector to a 784-element image.

    Calling an instance applies a dense layer, reshapes to `(32, 6, 6)`, runs
    three transposed convolutions, and returns the result reshaped to `(784,)`.
    A ReLU follows every layer, including the last one.
    """

    # [dense, conv_tr_1, conv_tr_2, conv_tr_3] — applied in order by `__call__`.
    layers: genjax.typing.List

    def flatten(self):
        # First tuple: the layers; second tuple is empty.
        # NOTE(review): presumably (dynamic children, static aux) per the
        # `genjax.Pytree` contract — confirm.
        return (self.layers,), ()

    @staticmethod
    def new(key, latent_dim):
        """Build a decoder with freshly initialized layers.

        Declared as a static method so that it behaves the same whether it is
        called on the class or on an instance.

        Args:
            key: PRNG key, split so each layer gets its own key.
            latent_dim: Size of the latent vector consumed by the dense layer.

        Returns:
            A new `DecoderNetwork`.
        """
        key, sub_key = jax.random.split(key)
        # 6 * 6 * 32 outputs so the activations can be reshaped to (32, 6, 6).
        dense = eqx.nn.Linear(
            in_features=latent_dim, out_features=6 * 6 * 32, key=sub_key
        )
        key, sub_key = jax.random.split(key)
        conv_tr_1 = eqx.nn.ConvTranspose2d(
            in_channels=32, out_channels=64, kernel_size=3, stride=2, key=sub_key
        )
        key, sub_key = jax.random.split(key)
        conv_tr_2 = eqx.nn.ConvTranspose2d(
            in_channels=64, out_channels=32, kernel_size=3, stride=2, key=sub_key
        )
        conv_tr_3 = eqx.nn.ConvTranspose2d(
            in_channels=32, out_channels=1, kernel_size=2, stride=1, key=key
        )
        layers = [dense, conv_tr_1, conv_tr_2, conv_tr_3]
        return DecoderNetwork(layers)

    def __call__(self, v):
        """Decode one latent vector into a flat image of 784 values."""
        v = jax.nn.relu(self.layers[0](v))
        v = v.reshape(32, 6, 6)
        for layer in self.layers[1:]:
            # ReLU is applied after the final layer as well, so outputs are >= 0.
            v = jax.nn.relu(layer(v))
        v = v.reshape(784)
        return v

## Training

In [None]:
# Re-seed the PRNG key so the training cell can be re-run on its own.
key = jax.random.PRNGKey(314159)
# Step size handed to the Adam optimizer below.
learning_rate = 1e-3
# Infinite generator of shuffled mini-batches of 64 images.
iter_data = dataloader(train_images, 64)
# Each network is initialized from its own sub-key; both use a latent size of 10,
# matching the 10-dimensional "latent" address in `decoder_model`.
key, sub_key = jax.random.split(key)
encoder_net = EncoderNetwork.new(sub_key, 10)
key, sub_key = jax.random.split(key)
decoder_net = DecoderNetwork.new(sub_key, 10)
# Number of optimizer steps run by the training loop.
steps = 100000


@jax.jit
def make_step(key, encoder_net, decoder_net, data, opt_state):
    """Run one optimizer step on a mini-batch.

    Reads the module-level optimizer `optim`, so that name must be bound before
    the first call.

    Returns:
        `(loss, (encoder_net, decoder_net), opt_state, mean_grad)`, where
        `mean_grad` is the mean over all gradient leaves of each leaf's mean,
        computed on the sign-flipped gradients.
    """
    loss, estimated_grads = minibatch_value_and_grad(
        key, data, encoder_net, decoder_net
    )
    # Flip the sign of every gradient leaf before handing it to the optimizer
    # (presumably because the estimated objective is to be maximized — confirm).
    flipped = jtu.tree_map(lambda leaf: -leaf, estimated_grads)
    updates, opt_state = optim.update(flipped, opt_state)
    networks = eqx.apply_updates((encoder_net, decoder_net), updates)
    encoder_net, decoder_net = networks
    # Scalar diagnostic: average of the per-leaf means.
    leaf_means = [jnp.mean(leaf) for leaf in jtu.tree_leaves(flipped)]
    mean_grad = jnp.mean(jnp.array(leaf_means))
    return loss, (encoder_net, decoder_net), opt_state, mean_grad


# Adam optimizer over the (encoder, decoder) pair; `make_step` reads `optim` globally.
optim = optax.adam(learning_rate)
opt_state = optim.init((encoder_net, decoder_net))
# `zip` has no length, so tell tqdm how many steps to expect; otherwise the
# progress bar cannot show a total or an ETA.
for step, image_batch in tqdm(zip(range(steps), iter_data), total=steps):
    key, sub_key = jax.random.split(key)
    loss, (encoder_net, decoder_net), opt_state, mean_grad = make_step(
        sub_key, encoder_net, decoder_net, image_batch, opt_state
    )
    if step % 1000 == 0:
        # `.item()` copies the value to the host and waits for the step to finish,
        # so it is only called on the steps that are actually logged.
        print(loss.item())

In [None]:
# Draw one 10-dimensional latent with zero location and unit diagonal (the same
# parameters as the "latent" address in `decoder_model`), decode it with the trained
# decoder, and display the result as a 28x28 image.
key, sub_key = jax.random.split(key)
latent = genjax.tfp_mv_normal_diag.sample(sub_key, jnp.zeros(10), jnp.ones(10))
plt.imshow(decoder_net(latent).reshape(28, 28))