In [None]:
%load_ext autoreload
%autoreload 2

import jax
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.distributions.flows import BlockNeuralAutoregressiveTransform
from numpyro.nn import BlockNeuralAutoregressiveNN
import numpyro
from jax.example_libraries import stax




  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def encoder(hidden_dim, z_dim):
    """Build a stax MLP encoder with two output heads.

    The input passes through ``Dense(hidden_dim) -> Softplus`` and is then fanned
    out to two ``Dense(z_dim)`` heads. The second head is followed by ``Exp``, so
    its outputs are strictly positive (presumably used as a scale, with the first
    head as a location — confirm against the guide that consumes it).

    Args:
        hidden_dim: width of the shared hidden layer.
        z_dim: width of each output head.

    Returns:
        The stax ``(init_fun, apply_fun)`` pair; ``apply_fun`` yields the two
        heads' outputs in order.
    """
    def dense(width):
        # Every Dense layer draws its weights from stax.randn().
        return stax.Dense(width, W_init=stax.randn())

    shared_trunk = [dense(hidden_dim), stax.Softplus]
    first_head = dense(z_dim)
    positive_head = stax.serial(dense(z_dim), stax.Exp)
    return stax.serial(
        *shared_trunk,
        stax.FanOut(2),
        stax.parallel(first_head, positive_head),
    )

In [None]:
def decoder(hidden_dim, out_dim):
    """Build a stax MLP decoder producing per-dimension probabilities.

    Architecture: ``Dense(hidden_dim) -> Softplus -> Dense(out_dim) -> Sigmoid``.
    The final sigmoid keeps every output inside (0, 1), so the result can be used
    directly as Bernoulli probabilities.

    Args:
        hidden_dim: width of the hidden layer.
        out_dim: number of outputs (e.g. flattened pixels).

    Returns:
        The stax ``(init_fun, apply_fun)`` pair.
    """
    def dense(width):
        return stax.Dense(width, W_init=stax.randn())

    hidden_stage = [dense(hidden_dim), stax.Softplus]
    output_stage = [dense(out_dim), stax.Sigmoid]
    return stax.serial(*hidden_stage, *output_stage)


In [None]:
def model(batch, hidden_dim=400, z_dim=50, m_dim=None):
    """Generative model whose per-example latent ``z`` is a BNAF-warped Gaussian.

    A global latent ``m`` (shared by every example in the batch) locates an
    isotropic unit-scale Gaussian base distribution. A block neural
    autoregressive flow transforms that base distribution to give ``z`` for each
    example, and ``z`` is decoded into Bernoulli probabilities for the flattened
    observations.

    Args:
        batch: array whose leading axis indexes examples; all remaining axes are
            flattened into one observation vector per example.
        hidden_dim: width of the decoder's hidden layer.
        z_dim: dimensionality of the latent space, shared by ``m``, the flow and ``z``.
        m_dim: retained for backward compatibility. Defaults to ``z_dim``. If
            given, it must equal ``z_dim``: the flow preserves dimension and its
            output ``z`` is fed straight into the decoder, which expects
            ``z_dim`` inputs.

    Returns:
        The value of the observed ``"obs"`` sample site (the flattened ``batch``).

    Raises:
        ValueError: if ``m_dim`` is given and differs from ``z_dim``.
    """
    if m_dim is None:
        m_dim = z_dim
    if m_dim != z_dim:
        raise ValueError(
            f"m_dim ({m_dim}) must equal z_dim ({z_dim}): the flow output z is fed "
            "directly into the decoder, which expects z_dim inputs."
        )
    latent_dim = z_dim

    batch = jnp.reshape(batch, (batch.shape[0], -1))
    batch_dim, out_dim = jnp.shape(batch)

    # m is the global shared variable; it locates the flow's base distribution.
    m = numpyro.sample("m", dist.Normal(0, 1).expand([latent_dim]).to_event())

    # Configure the flow.
    apply_flow = numpyro.module(
        "flow",
        BlockNeuralAutoregressiveNN(input_dim=latent_dim, hidden_factors=[8, 8]),
        input_shape=(batch_dim, latent_dim),
    )
    flow_transform = BlockNeuralAutoregressiveTransform(apply_flow)

    # Configure the decoder.
    decode = numpyro.module("decoder", decoder(hidden_dim, out_dim), (batch_dim, latent_dim))

    # Identity covariance == unit per-dimension scale. dist.Normal takes an
    # elementwise scale, so passing jnp.eye(...) here would broadcast into a
    # spurious (d, d) batch shape and put zero scales off the diagonal.
    base_dist = dist.Normal(m, 1.0).to_event(1)
    # NOTE(review): numpyro's BNAF transform has no analytic inverse, so scoring
    # a z supplied from outside (e.g. by a guide during SVI) may raise —
    # confirm before using this model for inference.
    flow_dist = dist.TransformedDistribution(base_dist, flow_transform)
    with numpyro.plate("batch", batch_dim):
        z = numpyro.sample("z", flow_dist)
        img_loc = decode(z)
        return numpyro.sample("obs", dist.Bernoulli(img_loc).to_event(1), obs=batch)

In [None]:
def model(batch, hidden_dim=400, z_dim=None):
    """Baseline generative model: ``z ~ Normal(mu, 1)`` per example, decoded to Bernoulli outputs.

    A global latent ``mu`` (standard-normal prior, one entry per latent
    dimension) is shared across the batch. Each example draws ``z`` around
    ``mu`` with unit scale, and the decoder maps ``z`` to Bernoulli
    probabilities for the flattened observations.

    NOTE(review): this cell redefines ``model``, shadowing the flow-based
    ``model`` defined in the earlier cell. After Restart & Run All only this
    version is bound to the name; consider giving one of them a distinct name.

    Args:
        batch: array whose leading axis indexes examples; all remaining axes are
            flattened into one observation vector per example.
        hidden_dim: width of the decoder's hidden layer.
        z_dim: latent dimensionality. Required despite the ``None`` default.

    Returns:
        The value of the observed ``"obs"`` sample site (the flattened ``batch``).

    Raises:
        ValueError: if ``z_dim`` is not given.
    """
    # Fail early with a clear message instead of an obscure error from
    # expand([None]) or from the decoder's input shape.
    if z_dim is None:
        raise ValueError("z_dim must be given: it sets the size of the latent space.")

    batch = jnp.reshape(batch, (batch.shape[0], -1))
    batch_dim, out_dim = jnp.shape(batch)
    mu = numpyro.sample("mu", dist.Normal(0, 1).expand([z_dim]).to_event())
    decode = numpyro.module("decoder", decoder(hidden_dim, out_dim), (batch_dim, z_dim))
    with numpyro.plate("batch", batch_dim):
        z = numpyro.sample("z", dist.Normal(mu, 1).to_event(1))
        img_loc = decode(z)
        return numpyro.sample("obs", dist.Bernoulli(img_loc).to_event(1), obs=batch)

