In [1]:
from rec.models.mnist_vae import MNISTVAE
from rec.core.utils import setup_logger

import tensorflow as tf
tfl = tf.keras.layers

import tensorflow_probability as tfp
tfd = tfp.distributions

import tensorflow_datasets as tfds

import numpy as np
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

In [2]:
# Make TensorFlow see no GPU devices, so every op in this notebook runs on the CPU.
# Workaround taken from https://github.com/tensorflow/tensorflow/issues/31135#issuecomment-516526113
tf.config.experimental.set_visible_devices(devices=[], device_type='GPU')

In [3]:
# List the experiment subdirectories that currently exist under the MNIST models root.
# NOTE(review): the saved output shows only `gaussian`, yet a later cell restores a `mog`
# checkpoint — this listing was presumably produced before the MoG run existed.
!ls ../../../../models/relative-entropy-coding/empirical-bayes-experiments/mnist

gaussian


In [4]:
# Checkpoint directories for the three experiment variants (gaussian, mog, snis),
# all living under one shared experiment root.
experiment_root = "../../../../models/relative-entropy-coding/empirical-bayes-experiments/mnist"

gaussian_save_dir = f"{experiment_root}/gaussian"
mog_save_dir = f"{experiment_root}/mog"
snis_save_dir = f"{experiment_root}/snis"

In [5]:
# Gaussian-prior VAE: independent standard-normal prior over each latent dimension.
latent_dim = 50

gaussian_vae = MNISTVAE(name="gaussian_mnist_vae",
                        prior=tfd.Normal(loc=tf.zeros(latent_dim), scale=tf.ones(latent_dim)))

gaussian_ckpt = tf.train.Checkpoint(model=gaussian_vae)

gaussian_manager = tf.train.CheckpointManager(gaussian_ckpt, gaussian_save_dir, max_to_keep=3)

# Dummy forward pass on one all-zero 28x28x1 image; presumably this builds the model's
# variables before the checkpoint values are restored into them — TODO confirm.
gaussian_vae(tf.zeros([1, 28, 28, 1]))
gaussian_ckpt.restore(gaussian_manager.latest_checkpoint)

if gaussian_manager.latest_checkpoint:
    print(f"Restored {gaussian_manager.latest_checkpoint}")
else:
    # restore(None) does not fail: the model silently keeps its freshly initialised
    # weights, so every likelihood computed below would be meaningless. Say so loudly.
    print(f"WARNING: no checkpoint found in {gaussian_save_dir}; gaussian_vae is UNTRAINED")

Restored ../../../../models/relative-entropy-coding/empirical-bayes-experiments/mnist/gaussian/ckpt-110


In [25]:
# Create MoG prior: each latent dimension gets an independent 1-D mixture of
# `num_components` Gaussians with uniform mixing weights.
latent_dim = 50
num_components = 100

probs = tf.ones([latent_dim, num_components]) / num_components
loc = tf.Variable(tf.random.uniform(shape=(num_components, latent_dim), minval=-1., maxval=1.))
log_scale = tf.Variable(tf.random.uniform(shape=(num_components, latent_dim), minval=-1., maxval=1.))

# NOTE(review): in eager mode `scale` and the `loc[i, :]` slices below are computed once,
# right here, from the Variables' current (random) values. The component Normals therefore
# hold fixed tensors, not references to `loc` / `log_scale`, so restoring or training those
# Variables afterwards would NOT change the prior. Confirm that MNISTVAE / the checkpoint
# track the prior parameters some other way before trusting MoG results.
# Also note that `log_scale` is mapped through softplus (plus a 1e-5 floor), not exp.
scale = 1e-5 + tf.nn.softplus(log_scale)

components = [tfd.Normal(loc=loc[i, :], scale=scale[i, :]) for i in range(num_components)]

mixture = tfd.Mixture(cat=tfd.Categorical(probs=probs),
                      components=components)

# Instantiate model
mog_vae = MNISTVAE(name="mog_mnist_vae",
                   prior=mixture)

mog_ckpt = tf.train.Checkpoint(model=mog_vae)

mog_manager = tf.train.CheckpointManager(mog_ckpt, mog_save_dir, max_to_keep=3)

# Dummy forward pass on one all-zero 28x28x1 image; presumably builds the model's
# variables before restoring — TODO confirm.
mog_vae(tf.zeros([1, 28, 28, 1]))
mog_ckpt.restore(mog_manager.latest_checkpoint)

if mog_manager.latest_checkpoint:
    print(f"Restored {mog_manager.latest_checkpoint}")
else:
    # restore(None) is a silent no-op: the model keeps its freshly initialised weights.
    print(f"WARNING: no checkpoint found in {mog_save_dir}; mog_vae is UNTRAINED")

Restored ../../../../models/relative-entropy-coding/empirical-bayes-experiments/mnist/mog/ckpt-1414


In [6]:
# Load binarized MNIST through TFDS (stored under the data_dir below) and keep only the
# test split, mapped to its images cast to float32.
dataset = tfds.load("binarized_mnist", data_dir="/scratch/gf332/datasets/binarized_mnist")
test_ds = dataset["test"].map(lambda example: tf.cast(example["image"], tf.float32))

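**IWAE lower bound.** For $k$ importance samples drawn independently from the approximate posterior $q(z \mid x)$,

$$
\log p(x) \geq \mathbb{E}_{z_1,\ldots,z_k \sim q(z \mid x)}\left[ \log\left(\frac{1}{k} \sum_{i=1}^k \frac{p(x, z_i)}{q(z_i\mid x)}\right)\right].
$$

Factorising $p(x, z_i) = p(x \mid z_i)\,p(z_i)$ and moving to log space, the right-hand side equals

$$
\mathbb{E}_{z_1,\ldots,z_k \sim q(z \mid x)}\left[ -\log k + \log \sum_{i=1}^k\exp\left\{\log p(x \mid z_i) + \log p(z_i) - \log q(z_i \mid x)\right\} \right].
$$

The next cell estimates this expectation with a single draw of $k$ samples per test image (`K` in the code). It uses `logsumexp` for numerical stability.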
$$

In [7]:
K = 5000            # importance samples per test image
num_samples = 100   # number of test images evaluated

# Clip range for the decoder's Bernoulli probabilities. The previous upper bound
# `1 - 1e-20` evaluates to exactly 1.0 (already in Python float64, hence also in float32),
# so it never clipped. A saturated decoder output of 1.0 then gives log(1 - p) = -inf,
# or possibly NaN depending on how Bernoulli converts probs to logits.
# Use the largest float32 strictly below 1 instead.
MIN_BERNOULLI_PROB = 1e-20
MAX_BERNOULLI_PROB = np.nextafter(np.float32(1.0), np.float32(0.0))

model = gaussian_vae


def iwae_log_likelihood(vae, image, num_importance_samples):
    """Single-draw IWAE estimate of log p(image) under `vae`.

    Args:
        vae: the VAE to evaluate (an MNISTVAE instance in this notebook); it must expose
            `posterior`, `prior` and `decoder` after being called on a batch.
        image: a single image tensor; presumably float32 of shape [28, 28, 1] — TODO confirm.
        num_importance_samples: the number k of posterior samples.

    Returns:
        Scalar tensor: logsumexp_i(log p(x|z_i) + log p(z_i) - log q(z_i|x)) - log k.
        The value is stochastic, because the posterior samples are drawn without a fixed seed.
    """
    # The forward pass is kept only for its side effect: the code below reads
    # `vae.posterior`, which presumably is set by this call for the given input.
    vae(image[None, ...])

    samples = tf.reshape(vae.posterior.sample(num_importance_samples),
                         [num_importance_samples, -1])

    # Sum per-dimension log-densities over the latent axis -> shape [k].
    post_log_prob = tf.reduce_sum(vae.posterior.log_prob(samples), axis=1)
    prior_log_prob = tf.reduce_sum(vae.prior.log_prob(samples), axis=1)

    pixel_probs = tf.clip_by_value(vae.decoder(samples), MIN_BERNOULLI_PROB, MAX_BERNOULLI_PROB)
    likelihood_dist = tfd.Bernoulli(probs=pixel_probs)
    # Sum per-pixel log-probabilities over every axis except the sample axis -> shape [k].
    log_likelihood = tf.einsum("ijkl -> i", likelihood_dist.log_prob(image))

    log_weights = prior_log_prob + log_likelihood - post_log_prob

    return (tf.reduce_logsumexp(log_weights)
            - tf.math.log(tf.cast(num_importance_samples, tf.float32)))


log_liks = [iwae_log_likelihood(model, image, K)
            for image in tqdm(test_ds.take(num_samples), total=num_samples)]

HBox(children=(FloatProgress(value=0.0), HTML(value='')))




In [8]:
# Summarise the per-image IWAE estimates. tf.nn.moments returns the mean and the
# (biased, divide-by-N) variance across the evaluated test images.
log_lik_tensor = tf.stack(log_liks)
mean, var = tf.nn.moments(log_lik_tensor, axes=[0])
std_dev = tf.sqrt(var)

print(f"Mean: {mean:.4f}")
print(f"Standard deviation: {std_dev:.4f}")

Mean: -89.2150
Standard deviation: 22.7186


In [30]:
# Scratch broadcasting check: 5 random standard-normal "observations", each 2x2.
test_batch = tf.random.normal([5, 2, 2])

In [29]:
# Scratch broadcasting check: standard-normal distribution with batch shape (3, 2, 2),
# i.e. 3 "components" over a 2x2 grid.
test_comps = tfd.Normal(loc=tf.zeros([3, 2, 2]), scale=tf.ones([3, 2, 2]))

In [34]:
# Insert a component axis so the (5, 2, 2) batch broadcasts against (3, 2, 2) -> (5, 3, 2, 2).
test_comps.log_prob(tf.expand_dims(test_batch, axis=1))

<tf.Tensor: id=193637, shape=(5, 3, 2, 2), dtype=float32, numpy=
array([[[[-1.8075321 , -2.141438  ],
         [-1.0765419 , -1.098341  ]],

        [[-1.8075321 , -2.141438  ],
         [-1.0765419 , -1.098341  ]],

        [[-1.8075321 , -2.141438  ],
         [-1.0765419 , -1.098341  ]]],


       [[[-0.98604095, -1.0405747 ],
         [-1.2297792 , -1.5500066 ]],

        [[-0.98604095, -1.0405747 ],
         [-1.2297792 , -1.5500066 ]],

        [[-0.98604095, -1.0405747 ],
         [-1.2297792 , -1.5500066 ]]],


       [[[-1.2495228 , -0.9701338 ],
         [-0.94274753, -2.5368283 ]],

        [[-1.2495228 , -0.9701338 ],
         [-0.94274753, -2.5368283 ]],

        [[-1.2495228 , -0.9701338 ],
         [-0.94274753, -2.5368283 ]]],


       [[[-0.9446475 , -1.7761207 ],
         [-1.8741846 , -1.108798  ]],

        [[-0.9446475 , -1.7761207 ],
         [-1.8741846 , -1.108798  ]],

        [[-0.9446475 , -1.7761207 ],
         [-1.8741846 , -1.108798  ]]],


       [[[-1.19