In [1]:
from rec.models.mnist_vae import MNISTVAE
from rec.core.utils import setup_logger

import tensorflow as tf
tfl = tf.keras.layers

import tensorflow_probability as tfp
tfd = tfp.distributions

import tensorflow_datasets as tfds

import numpy as np
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

In [2]:
# Hide every GPU from TensorFlow so that only the CPU remains visible to it.
# Recipe from
# https://github.com/tensorflow/tensorflow/issues/31135#issuecomment-516526113
tf.config.experimental.set_visible_devices(devices=[], device_type="GPU")

In [3]:
# Checkpoint directories for the three experiments (gaussian, mog, snis).
# They all sit under one shared experiment root, given relative to the
# notebook's working directory.
snis_experiments_root = "../../../../models/relative-entropy-coding/snis-experiments"

gaussian_save_dir, mog_save_dir, snis_save_dir = (
    f"{snis_experiments_root}/{variant}" for variant in ("gaussian", "mog", "snis")
)

In [4]:
# Build the VAE with a 50-dimensional standard-normal prior and restore its
# most recent checkpoint from `gaussian_save_dir`.
gaussian_vae = MNISTVAE(name="gaussian_mnist_vae", prior=tfd.Normal(loc=tf.zeros(50), scale=tf.ones(50)))

gaussian_ckpt = tf.train.Checkpoint(model=gaussian_vae)

gaussian_manager = tf.train.CheckpointManager(gaussian_ckpt, gaussian_save_dir, max_to_keep=3)

# One dummy forward pass on a blank 28x28x1 image before restoring --
# presumably so that the model's variables already exist when the checkpoint
# values are assigned (TODO confirm against MNISTVAE).
gaussian_vae(tf.zeros([1, 28, 28, 1]))

# `latest_checkpoint` is None when the directory holds no checkpoint; in that
# case `restore` does nothing and the model keeps its initial weights.
gaussian_latest_checkpoint = gaussian_manager.latest_checkpoint
gaussian_ckpt.restore(gaussian_latest_checkpoint)

if gaussian_latest_checkpoint:
    print(f"Restored {gaussian_latest_checkpoint}")
else:
    # Make the failure visible: without this branch the notebook would go on
    # to evaluate an untrained model without any indication.
    print(f"WARNING: no checkpoint found in {gaussian_save_dir}; "
          "gaussian_vae was NOT restored and still has its initial weights.")

Restored ../../../../models/relative-entropy-coding/snis-experiments/gaussian/ckpt-3330


In [5]:
# Load every split of binarized MNIST from the local dataset directory.
dataset = tfds.load(name="binarized_mnist",
                    data_dir="/scratch/gf332/datasets/binarized_mnist")


def image_as_float32(example):
    """Return the "image" entry of a dataset example, cast to float32."""
    return tf.cast(example["image"], tf.float32)


# Test split, reduced to bare float32 image tensors.
test_ds = dataset["test"].map(image_as_float32)

In [17]:
# Importance-weighted estimate of the log-likelihood of each test image:
#
#     log p(x) ~= logsumexp_k[ log p(z_k) + log p(x | z_k) - log q(z_k | x) ] - log K
#
# with z_1..z_K drawn from the model's posterior for that image.
K = 1000             # importance samples drawn per image
num_samples = 10000  # number of test images to evaluate

# Bounds used to keep the Bernoulli probabilities away from 0 and 1 so that
# neither log(p) nor log(1 - p) is -inf. The upper bound must be representable:
# `1 - 1e-16` rounds to exactly 1.0 in float32 and would therefore clip nothing,
# so the largest float32 strictly below 1.0 is used instead (assuming the
# decoder emits float32, like the float32 images fed to the model).
PROB_MIN = 1e-16
PROB_MAX = float(1.0 - np.finfo(np.float32).eps)

log_liks = []

model = gaussian_vae

for image in tqdm(test_ds.take(num_samples), total=num_samples):

    # Forward pass on a batch of one. The output is not needed; the call is
    # kept because `model.posterior` is read right below and is presumably
    # set for the current image by this call (TODO confirm against MNISTVAE).
    model(image[None, ...])

    # K posterior samples, flattened to one row per sample.
    samples = tf.reshape(model.posterior.sample(K), [K, -1])

    # log q(z_k | x): per-dimension log-densities summed over the latent axis.
    post_prob = model.posterior.log_prob(samples)
    post_prob = tf.reduce_sum(post_prob, axis=1)

    # log p(z_k) under the prior, summed the same way.
    prior_prob = model.prior.log_prob(samples)
    prior_prob = tf.reduce_sum(prior_prob, axis=1)

    # Decode every sample into per-pixel Bernoulli probabilities.
    likelihood_loc = model.decoder(samples)
    likelihood_dist = tfd.Bernoulli(probs=tf.clip_by_value(likelihood_loc, PROB_MIN, PROB_MAX))

    # log p(x | z_k): the image broadcasts against the K decoded images and
    # the per-pixel terms are summed over the three trailing axes.
    likelihood = likelihood_dist.log_prob(image)
    likelihood = tf.reduce_sum(likelihood, axis=[1, 2, 3])

    log_weights = prior_prob + likelihood - post_prob

    # Log of the mean importance weight: logsumexp minus log K.
    log_lik = tf.reduce_logsumexp(log_weights)
    log_lik = log_lik - tf.math.log(tf.cast(K, tf.float32))

    log_liks.append(log_lik)

HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))




In [16]:
# Mean and spread of the per-image log-likelihood estimates collected in
# `log_liks`.
log_lik_values = tf.convert_to_tensor(log_liks)
mean, var = tf.nn.moments(log_lik_values, axes=[0])
std = tf.sqrt(var)

print(f"Mean: {mean:.4f}")
print(f"Standard deviation: {std:.4f}")

Mean: -140.7995
Standard deviation: 38.5082
