In [None]:
import argparse
import inspect
import os
import time

import matplotlib.pyplot as plt

import jax.random as jrandom
from jax import jit, lax, random
from jax.example_libraries import stax
import jax.numpy as jnp
from jax.random import PRNGKey

import numpyro
from numpyro import optim
import numpyro.distributions as dist
from numpyro.examples.datasets import MNIST, load_dataset
from numpyro.infer import SVI, Trace_ELBO
import equinox as eq

In [None]:
# Directory that reconstruct_img writes its per-epoch PNG files into.
RESULTS_DIR = "./results"
# plt.imsave does not create missing parent directories, so make sure the
# target exists before the first image is saved (no-op if it already does).
os.makedirs(RESULTS_DIR, exist_ok=True)

## Params

In [None]:
class Args:
    """Hyperparameters shared by the networks and the training loop."""

    # Network sizes
    z_dim: int = 50  # size of the latent vector
    hidden_dim: int = 400  # hidden width passed to encoder() and decoder()

    # Optimisation settings
    learning_rate: float = 1e-3  # step size given to optim.Adam
    batch_size: int = 128  # batch size given to load_dataset
    num_epochs: int = 15  # iterations of the outer training loop


args = Args()

## Data

In [None]:
# Root PRNG key; the keys used in the cells below are split from it.
rng_key = PRNGKey(0)

# load_dataset returns a pair that is unpacked into an init and a fetch function
# for each of the two MNIST splits.
train_init, train_fetch = load_dataset(
    MNIST,
    batch_size=args.batch_size,
    split="train",
)
test_init, test_fetch = load_dataset(
    MNIST,
    batch_size=args.batch_size,
    split="test",
)

In [None]:
@jit
def binarize(rng_key, batch):
    """Draw a 0/1 array from ``batch``, treating every entry as a probability.

    Args:
        rng_key: JAX PRNG key used for the Bernoulli draw.
        batch: array of probabilities handed to ``random.bernoulli``.

    Returns:
        The Bernoulli draw cast back to ``batch.dtype`` (same shape as ``batch``).
    """
    draws = random.bernoulli(rng_key, batch)
    return draws.astype(batch.dtype)

In [None]:
# Initialise the training loader. num_train is later used as the loop bound in
# epoch_train and train_idx is handed to train_fetch.
num_train, train_idx = train_init()
# Split the root key into a fresh root plus keys for binarisation, parameter
# initialisation and latent sampling.
rng_key, rng_key_binarize, rng_key_init, rng_key_latent = random.split(rng_key, 4)
# Binarised element 0 of batch 0 (presumably the images, with labels dropped —
# TODO confirm against load_dataset); used as a sample input in later cells.
sample_batch = binarize(rng_key_binarize, train_fetch(0, train_idx)[0])

## Background - Stax

### Encoder

In [None]:
class Encoder(eq.Module):
    """Equinox encoder module (work in progress: it only stores its configuration).

    Attributes:
        hidden_dim: hidden-layer width, declared as a static field.
    """

    # Equinox modules are dataclasses: a field must carry a type annotation to
    # be registered, otherwise it cannot be assigned in __init__.
    hidden_dim: int = eq.static_field()

    def __init__(self, hidden_dim):
        """Store the hidden-layer width on the module."""
        self.hidden_dim = hidden_dim

In [None]:
def encoder(hidden_dim, z_dim):
    """Build the stax encoder: one hidden layer feeding two ``z_dim``-wide heads.

    Args:
        hidden_dim: width of the shared Dense + Softplus hidden layer.
        z_dim: output width of each of the two heads.

    Returns:
        The ``(init_fun, apply_fun)`` pair produced by ``stax.serial``. Applying
        it yields two outputs: a plain Dense head and a Dense head passed
        through ``Exp``.
    """
    plain_head = stax.Dense(z_dim, W_init=stax.randn())
    # Exp keeps the second head strictly positive.
    positive_head = stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp)

    layers = [
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),  # duplicate the hidden activations, one copy per head
        stax.parallel(plain_head, positive_head),
    ]
    return stax.serial(*layers)

In [None]:
# Instantiate the stax encoder; index 0 is its init function, index 1 its apply function.
encoder_nn = encoder(args.hidden_dim, args.z_dim)

In [None]:
# Inspect the shape of the binarised sample batch.
sample_batch.shape

In [None]:
# jax.random.normal requires an explicit PRNG key (and optionally a shape).
# A literal key is used so none of the notebook's named keys is consumed here.
jrandom.normal(PRNGKey(0), shape=(3,))

In [None]:
# Latent-shaped noise. This is not an encoder input (in `guide` the encoder is
# applied to the flattened data batch); it is kept in case later cells use it.
sample_latent = jrandom.normal(rng_key_latent, shape=(args.batch_size, args.z_dim))

# Feed the encoder what `guide` feeds it: the batch flattened to
# (batch.shape[0], -1), rather than latent-shaped noise.
encoder_input = jnp.reshape(sample_batch, (sample_batch.shape[0], -1))
in_shape = encoder_input.shape

# init function: returns the output shape(s) and freshly initialised parameters
out_shape, params = encoder_nn[0](rng_key_init, in_shape)

# apply function: returns the outputs of the two encoder heads
x_out = encoder_nn[1](params, encoder_input)

In [None]:
# Display the encoder output computed in the previous cell.
x_out

In [None]:
# Display the output shape reported by the encoder's init function.
out_shape

### Decoder

In [None]:
def decoder(hidden_dim, out_dim):
    """Build the stax decoder: Dense + Softplus, then Dense + Sigmoid.

    Args:
        hidden_dim: width of the hidden Dense layer.
        out_dim: width of the output Dense layer.

    Returns:
        The ``(init_fun, apply_fun)`` pair produced by ``stax.serial``. The
        final Sigmoid keeps every output strictly between 0 and 1.
    """
    layers = (
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.Dense(out_dim, W_init=stax.randn()),
        stax.Sigmoid,
    )
    return stax.serial(*layers)

In [None]:
# Collapse every non-leading axis so each example becomes one flat row.
batch = sample_batch.reshape(sample_batch.shape[0], -1)
batch.shape

## NumPyro Model

In [None]:
def model(batch, hidden_dim=400, z_dim=100):
    """Generative model: a standard-normal latent decoded into a Bernoulli likelihood.

    Args:
        batch: observed data; flattened to ``(batch.shape[0], -1)`` before use.
        hidden_dim: hidden width passed to ``decoder``.
        z_dim: size of the latent vector sampled for each row of ``batch``.

    Returns:
        The value of the ``"obs"`` sample site.
    """
    flat = jnp.reshape(batch, (batch.shape[0], -1))
    n_rows, n_features = flat.shape
    # Register the decoder under the name "decoder"; the shape tuple is the
    # input shape used to initialise it.
    decoder_apply = numpyro.module(
        "decoder", decoder(hidden_dim, n_features), (n_rows, z_dim)
    )
    with numpyro.plate("batch", n_rows):
        prior = dist.Normal(0, 1).expand([z_dim]).to_event(1)
        latent = numpyro.sample("z", prior)
        probs = decoder_apply(latent)
        likelihood = dist.Bernoulli(probs).to_event(1)
        return numpyro.sample("obs", likelihood, obs=flat)

### NumPyro Guide

In [None]:
def guide(batch, hidden_dim=400, z_dim=100):
    """Variational guide: encode the batch into a diagonal Normal over ``z``.

    Args:
        batch: observed data; flattened to ``(batch.shape[0], -1)`` before use.
        hidden_dim: hidden width passed to ``encoder``.
        z_dim: size of the latent vector.

    Returns:
        The value of the ``"z"`` sample site.
    """
    flat = jnp.reshape(batch, (batch.shape[0], -1))
    n_rows, n_features = flat.shape
    # Register the encoder under the name "encoder"; the shape tuple is the
    # input shape used to initialise it.
    encoder_apply = numpyro.module(
        "encoder", encoder(hidden_dim, z_dim), (n_rows, n_features)
    )
    loc, scale = encoder_apply(flat)
    with numpyro.plate("batch", n_rows):
        posterior = dist.Normal(loc, scale).to_event(1)
        return numpyro.sample("z", posterior)

## Training

### Networks, optimizer and SVI state

In [None]:
# Stand-alone (init, apply) pairs; reconstruct_img calls their apply functions
# with the parameters taken from the SVI state.
encoder_nn = encoder(args.hidden_dim, args.z_dim)
# 28 * 28 is the flattened size that reconstruct_img reshapes back to [28, 28].
decoder_nn = decoder(args.hidden_dim, 28 * 28)

In [None]:
# Adam optimiser with the configured learning rate.
adam = optim.Adam(args.learning_rate)

# SVI with the Trace_ELBO loss; the trailing keyword arguments are the static
# arguments SVI passes on to `model` and `guide`.
svi = SVI(
    model,
    guide,
    adam,
    Trace_ELBO(),
    hidden_dim=args.hidden_dim,
    z_dim=args.z_dim,
)

In [None]:
# Build the initial SVI state from the init key and the sample batch.
svi_state = svi.init(rng_key_init, sample_batch)

In [None]:
@jit
def epoch_train(svi_state, rng_key, train_idx):
    """Run one SVI update per training batch for ``num_train`` batches.

    Args:
        svi_state: SVI state to start the epoch from.
        rng_key: epoch key; a per-batch binarisation key is folded out of it.
        train_idx: index object passed through to ``train_fetch``.

    Returns:
        ``(loss_sum, svi_state)``: the summed per-batch losses and the state
        after the last update.
    """

    def train_step(batch_no, carry):
        running_loss, state = carry
        # Distinct binarisation key for every batch, derived from the epoch key.
        step_key = random.fold_in(rng_key, batch_no)
        data = train_fetch(batch_no, train_idx)[0]
        state, step_loss = svi.update(state, binarize(step_key, data))
        return running_loss + step_loss, state

    initial_carry = (0.0, svi_state)
    return lax.fori_loop(0, num_train, train_step, initial_carry)

In [None]:
@jit
def eval_test(svi_state, rng_key, test_idx):
    """Average the per-example evaluation loss over ``num_test`` test batches.

    Args:
        svi_state: SVI state whose parameters are evaluated.
        rng_key: key from which a per-batch binarisation key is folded.
        test_idx: index object passed through to ``test_fetch``.

    Returns:
        The sum over batches of ``svi.evaluate(...) / len(batch)``, divided by
        ``num_test``.
    """

    def accumulate(batch_no, running_loss):
        step_key = random.fold_in(rng_key, batch_no)
        batch = binarize(step_key, test_fetch(batch_no, test_idx)[0])
        # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
        per_example_loss = svi.evaluate(svi_state, batch) / len(batch)
        return running_loss + per_example_loss

    total = lax.fori_loop(0, num_test, accumulate, 0.0)
    return total / num_test

In [None]:
def reconstruct_img(epoch, rng_key):
    """Save one test image and its reconstruction as PNG files in RESULTS_DIR.

    Reads the module-level ``svi_state`` and ``test_idx``. The image is
    binarised, encoded, a latent is sampled, and the decoder output is reshaped
    to ``[28, 28]`` before saving.

    Args:
        epoch: value interpolated into the two output file names.
        rng_key: key split into a binarisation key and a latent-sampling key.
    """
    original = test_fetch(0, test_idx)[0][0]
    original_path = os.path.join(RESULTS_DIR, "original_epoch={}.png".format(epoch))
    plt.imsave(original_path, original, cmap="gray")

    binarize_key, sample_key = random.split(rng_key)
    binary_img = binarize(binarize_key, original)

    trained = svi.get_params(svi_state)
    encode, decode = encoder_nn[1], decoder_nn[1]
    # The encoder's second output is used as the scale of the Normal.
    z_loc, z_scale = encode(trained["encoder$params"], binary_img.reshape([1, -1]))
    z = dist.Normal(z_loc, z_scale).sample(sample_key)
    reconstruction = decode(trained["decoder$params"], z).reshape([28, 28])

    recons_path = os.path.join(RESULTS_DIR, "recons_epoch={}.png".format(epoch))
    plt.imsave(recons_path, reconstruction, cmap="gray")

In [None]:
# Train for args.num_epochs epochs. Each epoch runs one pass of SVI updates,
# evaluates on the test split, saves a reconstruction and prints the test loss.
for i in range(args.num_epochs):
    # A single split yields the next root key plus every key this epoch needs.
    # The former second split overwrote the test and reconstruct keys, so two
    # of these keys were never used.
    rng_key, rng_key_train, rng_key_test, rng_key_reconstruct = random.split(rng_key, 4)
    t_start = time.time()
    num_train, train_idx = train_init()
    # The summed training loss is discarded; only the updated state is kept.
    _, svi_state = epoch_train(svi_state, rng_key_train, train_idx)
    # num_test and test_idx are read as globals by eval_test and reconstruct_img.
    num_test, test_idx = test_init()
    test_loss = eval_test(svi_state, rng_key_test, test_idx)
    reconstruct_img(i, rng_key_reconstruct)
    # The reported time covers training, evaluation and image saving.
    print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss, time.time() - t_start))