
# Example: Variational Autoencoder


In [1]:
%load_ext autoreload
%autoreload 2

import argparse
import inspect
import os
import time

import matplotlib.pyplot as plt

from jax import jit, lax, random
from jax.example_libraries import stax
import jax.numpy as jnp
from jax.random import PRNGKey
import jax

import numpyro
from numpyro import optim
import numpyro.distributions as dist
from numpyro.examples.datasets import MNIST, load_dataset
from numpyro.infer import SVI, Trace_ELBO

from types import SimpleNamespace

from numpyro.optim import Adagrad

from numpyro.contrib.einstein import RBFKernel
from mixture_guide_impl_source import MixtureGuidePredictive
from stein_impl_source import SteinVI

import vae_example


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Temp dir in kernel
RESULTS_DIR = os.path.abspath(
    os.path.join(os.path.dirname(inspect.getfile(lambda: None)), ".results")
)
RESULTS_DIR = "./smi_results"
os.makedirs(RESULTS_DIR, exist_ok=True)


In [3]:


def encoder(hidden_dim, z_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.FanOut(2),
        stax.parallel(
            stax.Dense(z_dim, W_init=stax.randn()),
            stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp),
        ),
    )


In [4]:

def decoder(hidden_dim, out_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.Dense(out_dim, W_init=stax.randn()),
        stax.Sigmoid,
    )



In [5]:
def model(batch, hidden_dim=400, z_dim=None):
    batch = jnp.reshape(batch, (batch.shape[0], -1))
    batch_dim, out_dim = jnp.shape(batch)
    decode = numpyro.module("decoder", decoder(hidden_dim, out_dim), (batch_dim, z_dim))
    with numpyro.plate("batch", batch_dim):
        z = numpyro.sample("z", dist.Normal(0, 1).expand([z_dim]).to_event(1))
        img_loc = decode(z)
        return numpyro.sample("obs", dist.Bernoulli(img_loc).to_event(1), obs=batch)



In [6]:


def guide(batch, hidden_dim=400, z_dim=100):
    batch = jnp.reshape(batch, (batch.shape[0], -1))
    batch_dim, out_dim = jnp.shape(batch)

    encode = numpyro.module("encoder", encoder(hidden_dim, z_dim), (batch_dim, out_dim))
    z_loc, z_std = encode(batch)
    with numpyro.plate("batch", batch_dim):
        d = dist.Normal(z_loc, z_std).to_event(1)
        
        z = numpyro.sample("z", d)
        return z



In [74]:
def time_f(func, active = False):
    def time_func(*args, **kwargs):
        t = time.time()
        result = func(*args, **kwargs)
        if(active):
            print(f"took {time.time() - t}s to run {func.__name__}")
        return result
    return time_func


def binarize(rng_key, batch):
    return random.bernoulli(rng_key, batch).astype(batch.dtype)


args = SimpleNamespace()
args.num_epochs = 50
args.learning_rate =5.0e-3
args.batch_size = 128
args.z_dim = 5
args.hidden_dim = 400
args.ada_step_size = 0.05

args.num_stein_particles = 3
args.num_elbo_particles = 11



encoder_nn = encoder(args.hidden_dim, args.z_dim)
decoder_nn = decoder(args.hidden_dim, 28 * 28)
adam = optim.Adam(args.learning_rate)
ada = Adagrad(args.ada_step_size)


method = SteinVI(
    model,
    guide,
    ada,
    RBFKernel(), 
    hidden_dim=args.hidden_dim, 
    z_dim=args.z_dim, 
    num_stein_particles=args.num_stein_particles, 
    num_elbo_particles=args.num_elbo_particles
    )


rng_key = PRNGKey(0)
train_init, train_fetch = load_dataset(
    MNIST, batch_size=args.batch_size, split="train"
)
test_init, test_fetch = load_dataset(
    MNIST, batch_size=args.batch_size, split="test"
)
num_train, train_idx = train_init()
rng_key, rng_key_binarize, rng_key_init = random.split(rng_key, 3)
sample_batch = binarize(rng_key_binarize, train_fetch(0, train_idx)[0])
state = method.init(rng_key_init, sample_batch)

In [79]:


@jit
def epoch_train(state, rng_key, train_idx):
    def body_fn(i, val):
        loss_sum, state = val
        rng_key_binarize = random.fold_in(rng_key, i)
        batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
        state, loss = method.update(state, batch)
        loss_sum += loss
        return loss_sum, state

    return lax.fori_loop(0, num_train, body_fn, (0.0, state))

@jit
def eval_data(state, rng_key, data_idx, num_test, data_fetch):
    def body_fun(i, loss_sum):
        rng_key_binarize = random.fold_in(rng_key, i)
        batch = binarize(rng_key_binarize, data_fetch(i, data_idx)[0])
        # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
        loss = method.evaluate(state, batch) / len(batch)
        loss_sum += loss
        return loss_sum

    loss = lax.fori_loop(0, num_test, body_fun, 0.0)
    loss = loss / num_test
    return loss

def reconstruct_img(epoch, rng_key, test_idx):
    img = test_fetch(0, test_idx)[0][0]
    plt.imsave(
        os.path.join(RESULTS_DIR, "original_epoch={}.png".format(epoch)),
        img,
        cmap="gray",
    )
    rng_key_binarize, rng_key_sample, rng_key_particle = random.split(rng_key, 3)
    test_sample = binarize(rng_key_binarize, img)

    params = method.get_params(state)
    particle = jax.random.randint(rng_key_particle, 1, 0, args.num_stein_particles)
    encoder_params = jax.tree.map(lambda x: x[particle], params["encoder$params"])
    decoder_params = params["decoder$params"]

    z_mean, z_var = encoder_nn[1](
        encoder_params, test_sample.reshape([1, -1])
    )
    z = dist.Normal(z_mean, z_var).sample(rng_key_sample)
    img_loc = decoder_nn[1](decoder_params, z).reshape([28, 28])
    plt.imsave(
        os.path.join(RESULTS_DIR, "recons_epoch={}.png".format(epoch)),
        img_loc,
        cmap="gray",
    )

def sample_imgs(rng_key, num_samples):
    
    params = method.get_params(state)
    decoder_params = params["decoder$params"]

    for i in range(num_samples):
        rng_key, sub_key = random.split(rng_key)
        z = dist.Normal(0, 1).expand((args.z_dim,)).sample(sub_key)
        img_loc = decoder_nn[1](decoder_params, z).reshape([28, 28])
        plt.imsave(
            f"samples/sample{i}.png",
            img_loc,
            cmap="gray",
        )
        


In [80]:
for i in range(args.num_epochs):
    rng_key, rng_key_train, rng_key_test, rng_key_reconstruct = random.split(
        rng_key, 4
    )
    t_start = time.time()
    num_train, train_idx = train_init()
    
    _, state = time_f(epoch_train)(state, rng_key_train, train_idx)
    rng_key, rng_key_test, rng_key_train, rng_key_reconstruct = random.split(rng_key, 4)
    num_test, test_idx = test_init()
    test_loss = time_f(eval_data)(state, rng_key_test, test_idx, num_test, test_fetch)
    train_loss = time_f(eval_data)(state, rng_key_train, train_idx, num_train, train_fetch)
    time_f(reconstruct_img)(i, rng_key_reconstruct, test_idx)
    print(
        "Epoch {}: loss = {} ({:.2f} s.)".format(
            i, train_loss, time.time() - t_start
        )
    )
rng_key, rng_key_sample = random.split(rng_key)


TypeError: Error interpreting argument to <function eval_data at 0x7150e0382fc0> as an abstract array. The problematic value is of type <class 'function'> and was passed to the function at path data_fetch.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.

In [37]:
sample_imgs(rng_key,5)

In [33]:
reconstruct_img(100, rng_key, test_idx)

In [None]:
vae_example.main_with_args()

Epoch 0: loss = 141.37545776367188 (11.23 s.)
Epoch 1: loss = 121.41349029541016 (0.34 s.)
Epoch 2: loss = 115.60474395751953 (0.11 s.)
Epoch 3: loss = 112.30805969238281 (0.11 s.)
Epoch 4: loss = 110.07829284667969 (0.11 s.)
Epoch 5: loss = 109.3290786743164 (0.11 s.)
Epoch 6: loss = 108.16770935058594 (0.11 s.)
Epoch 7: loss = 107.49296569824219 (0.11 s.)
Epoch 8: loss = 106.88172149658203 (0.11 s.)
Epoch 9: loss = 106.052734375 (0.11 s.)
Epoch 10: loss = 105.74738311767578 (0.11 s.)
Epoch 11: loss = 105.2944107055664 (0.11 s.)
Epoch 12: loss = 104.90278625488281 (0.11 s.)
Epoch 13: loss = 104.87137603759766 (0.11 s.)
Epoch 14: loss = 103.96284484863281 (0.11 s.)


In [57]:
TO_DELETE=RESULTS_DIR

# List all files in the directory
for filename in os.listdir(TO_DELETE):
    file_path = os.path.join(TO_DELETE, filename)
    
    # Check if it is a file (not a subdirectory)
    if os.path.isfile(file_path):
        os.remove(file_path)  # Remove the file
        print(f"Deleted file: {filename}")
os.rmdir(TO_DELETE, )
os.makedirs(TO_DELETE, exist_ok=True)


Deleted file: original_epoch=0.png
Deleted file: original_epoch=1.png
Deleted file: original_epoch=10.png
Deleted file: original_epoch=100.png
Deleted file: original_epoch=11.png
Deleted file: original_epoch=12.png
Deleted file: original_epoch=13.png
Deleted file: original_epoch=14.png
Deleted file: original_epoch=15.png
Deleted file: original_epoch=16.png
Deleted file: original_epoch=17.png
Deleted file: original_epoch=18.png
Deleted file: original_epoch=19.png
Deleted file: original_epoch=2.png
Deleted file: original_epoch=20.png
Deleted file: original_epoch=21.png
Deleted file: original_epoch=22.png
Deleted file: original_epoch=23.png
Deleted file: original_epoch=24.png
Deleted file: original_epoch=25.png
Deleted file: original_epoch=26.png
Deleted file: original_epoch=27.png
Deleted file: original_epoch=28.png
Deleted file: original_epoch=29.png
Deleted file: original_epoch=3.png
Deleted file: original_epoch=30.png
Deleted file: original_epoch=31.png
Deleted file: original_epoch=32

In [20]:
dist.Normal(0, 1).expand((args.z_dim,)).sample(PRNGKey(5))

Array([-0.08437306,  1.411023  ,  0.63048154, -1.3100973 ,  1.3689315 ,
        0.46135852, -2.123845  , -1.6058723 , -0.8372669 , -0.16842504,
       -3.1651952 , -0.28706762,  0.47498658, -0.13759778,  1.006739  ,
       -1.4031589 , -0.88495755,  0.95043534,  0.43680725, -1.5774156 ,
        2.3743632 , -0.55763084, -0.09049941, -0.6494537 ,  0.13936585,
       -0.5250226 ,  1.6053667 , -0.9105035 , -1.3633244 , -1.8919259 ,
       -0.44076735, -0.8687112 , -1.3531274 , -0.29859722, -0.6176347 ,
       -1.1297423 ,  1.833774  , -0.48512915, -0.51443154, -0.1526566 ,
       -0.04869586, -1.5729389 ,  0.5369736 , -1.6266394 , -0.17068647,
        0.6348398 , -1.1996533 ,  0.330836  , -1.7632359 , -1.0963402 ,
        0.8981224 ], dtype=float32)